IFMedTechdemo committed on
Commit
83140b5
·
verified ·
1 Parent(s): 6f82d20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -61
app.py CHANGED
@@ -16,6 +16,9 @@ from transformers import (
16
  TextIteratorStreamer,
17
  )
18
 
 
 
 
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  # Choose best attention implementation based on device
@@ -43,6 +46,10 @@ processor = LightOnOCRProcessor.from_pretrained(
43
  )
44
  print("Model loaded successfully!")
45
 
 
 
 
 
46
 
47
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
48
  """Render a PDF page to PIL Image."""
@@ -69,35 +76,32 @@ def process_pdf(pdf_path, page_num=1):
69
 
70
  def clean_output_text(text):
71
  """Remove chat template artifacts from output."""
72
- # Remove common chat template markers
73
  markers_to_remove = ["system", "user", "assistant"]
74
-
75
- # Split by lines and filter
76
  lines = text.split('\n')
77
  cleaned_lines = []
78
-
79
  for line in lines:
80
  stripped = line.strip()
81
  # Skip lines that are just template markers
82
  if stripped.lower() not in markers_to_remove:
83
  cleaned_lines.append(line)
84
-
85
- # Join back and strip leading/trailing whitespace
86
  cleaned = '\n'.join(cleaned_lines).strip()
87
-
88
- # Alternative approach: if there's an "assistant" marker, take everything after it
89
  if "assistant" in text.lower():
90
  parts = text.split("assistant", 1)
91
  if len(parts) > 1:
92
  cleaned = parts[1].strip()
93
-
94
  return cleaned
95
 
 
 
 
 
 
 
 
96
 
97
  @spaces.GPU
98
  def extract_text_from_image(image, temperature=0.2, stream=False):
99
  """Extract text from image using LightOnOCR model."""
100
- # Prepare the chat format
101
  chat = [
102
  {
103
  "role": "user",
@@ -106,8 +110,6 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
106
  ],
107
  }
108
  ]
109
-
110
- # Apply chat template and tokenize
111
  inputs = processor.apply_chat_template(
112
  chat,
113
  add_generation_prompt=True,
@@ -115,15 +117,12 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
115
  return_dict=True,
116
  return_tensors="pt"
117
  )
118
-
119
- # Move inputs to device AND convert to the correct dtype
120
  inputs = {
121
  k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
122
  else v.to(device) if isinstance(v, torch.Tensor)
123
  else v
124
  for k, v in inputs.items()
125
  }
126
-
127
  generation_kwargs = dict(
128
  **inputs,
129
  max_new_tokens=2048,
@@ -131,54 +130,38 @@ def extract_text_from_image(image, temperature=0.2, stream=False):
131
  use_cache=True,
132
  do_sample=temperature > 0,
133
  )
134
-
135
  if stream:
136
- # Setup streamer for streaming generation
137
  streamer = TextIteratorStreamer(
138
  processor.tokenizer,
139
  skip_prompt=True,
140
  skip_special_tokens=True
141
  )
142
  generation_kwargs["streamer"] = streamer
143
-
144
- # Run generation in a separate thread
145
  thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
146
  thread.start()
147
-
148
- # Yield chunks as they arrive
149
  full_text = ""
150
  for new_text in streamer:
151
  full_text += new_text
152
- # Clean the accumulated text
153
  cleaned_text = clean_output_text(full_text)
154
  yield cleaned_text
155
-
156
  thread.join()
157
  else:
158
  # Non-streaming generation
159
  with torch.no_grad():
160
  outputs = model.generate(**generation_kwargs)
161
-
162
- # Decode the output
163
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
164
-
165
- # Clean the output
166
  cleaned_text = clean_output_text(output_text)
167
-
168
  yield cleaned_text
169
 
170
-
171
  def process_input(file_input, temperature, page_num, enable_streaming):
172
- """Process uploaded file (image or PDF) and extract text with optional streaming."""
173
  if file_input is None:
174
  yield "Please upload an image or PDF first.", "", "", None, gr.update()
175
  return
176
-
177
  image_to_process = None
178
  page_info = ""
179
-
180
  file_path = file_input if isinstance(file_input, str) else file_input.name
181
-
182
  # Handle PDF files
183
  if file_path.lower().endswith('.pdf'):
184
  try:
@@ -195,24 +178,20 @@ def process_input(file_input, temperature, page_num, enable_streaming):
195
  except Exception as e:
196
  yield f"Error opening image: {str(e)}", "", "", None, gr.update()
197
  return
198
-
199
  try:
200
- # Extract text using LightOnOCR with optional streaming
201
  for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
202
- yield extracted_text, extracted_text, page_info, image_to_process, gr.update()
203
-
 
204
  except Exception as e:
205
  error_msg = f"Error during text extraction: {str(e)}"
206
  yield error_msg, error_msg, page_info, image_to_process, gr.update()
207
 
208
-
209
  def update_slider(file_input):
210
  """Update page slider based on PDF page count."""
211
  if file_input is None:
212
  return gr.update(maximum=20, value=1)
213
-
214
  file_path = file_input if isinstance(file_input, str) else file_input.name
215
-
216
  if file_path.lower().endswith('.pdf'):
217
  try:
218
  pdf = pdfium.PdfDocument(file_path)
@@ -224,25 +203,23 @@ def update_slider(file_input):
224
  else:
225
  return gr.update(maximum=1, value=1)
226
 
227
-
228
- # Create Gradio interface
229
- with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
230
  gr.Markdown(f"""
231
- # 📖 Image/PDF to Text Extraction with LightOnOCR
232
 
233
  **💡 How to use:**
234
  1. Upload an image or PDF
235
- 2. For PDFs: select which page to extract (1-20)
236
  3. Adjust temperature if needed
237
- 4. Click "Extract Text"
238
 
239
- **Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
240
 
241
  **Model:** LightOnOCR-1B-1025 by LightOn AI
242
  **Device:** {device.upper()}
243
  **Attention:** {attn_implementation}
244
  """)
245
-
246
  with gr.Row():
247
  with gr.Column(scale=1):
248
  file_input = gr.File(
@@ -282,43 +259,37 @@ with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft(
282
  value=True,
283
  info="Show text progressively as it's generated"
284
  )
285
- submit_btn = gr.Button("Extract Text", variant="primary")
286
  clear_btn = gr.Button("Clear", variant="secondary")
287
-
288
  with gr.Column(scale=2):
289
  output_text = gr.Markdown(
290
- label="📄 Extracted Text (Rendered)",
291
- value="*Extracted text will appear here...*"
292
  )
293
-
294
  with gr.Row():
295
  with gr.Column():
296
  raw_output = gr.Textbox(
297
- label="Raw Markdown Output",
298
- placeholder="Raw text will appear here...",
299
  lines=20,
300
  max_lines=30,
301
  show_copy_button=True
302
  )
303
-
304
  # Event handlers
305
  submit_btn.click(
306
  fn=process_input,
307
  inputs=[file_input, temperature, num_pages, enable_streaming],
308
  outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
309
  )
310
-
311
  file_input.change(
312
  fn=update_slider,
313
  inputs=[file_input],
314
  outputs=[num_pages]
315
  )
316
-
317
  clear_btn.click(
318
- fn=lambda: (None, "*Extracted text will appear here...*", "", "", None, 1),
319
  outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
320
  )
321
 
322
-
323
  if __name__ == "__main__":
324
- demo.launch()
 
16
  TextIteratorStreamer,
17
  )
18
 
19
+ # ---- CLINICAL NER IMPORTS ----
20
+ import spacy
21
+
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
  # Choose best attention implementation based on device
 
46
  )
47
  print("Model loaded successfully!")
48
 
49
+ # ---- LOAD CLINICAL NER MODEL (BC5CDR) ----
50
+ print("Loading clinical NER model (bc5cdr)...")
51
+ nlp_ner = spacy.load("en_ner_bc5cdr_md")
52
+ print("Clinical NER loaded.")
53
 
54
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
55
  """Render a PDF page to PIL Image."""
 
76
 
def clean_output_text(text):
    """Remove chat-template artifacts (role markers) from generated output.

    Two passes: drop any line that is exactly a role marker, then — if an
    "assistant" marker exists anywhere in the text — prefer everything after
    its first (case-sensitive) occurrence.
    """
    role_markers = {"system", "user", "assistant"}

    # Keep only lines that are not bare role markers.
    kept = [ln for ln in text.split('\n') if ln.strip().lower() not in role_markers]
    cleaned = '\n'.join(kept).strip()

    # If an assistant marker is present, the answer is whatever follows it.
    # NOTE(review): the guard is case-insensitive but the partition is
    # case-sensitive — preserved as-is to match existing behavior.
    if "assistant" in text.lower():
        _, sep, tail = text.partition("assistant")
        if sep:
            cleaned = tail.strip()

    return cleaned
93
 
94
+ def extract_medication_names(text):
95
+ """Extract medication names using clinical NER (spacy: bc5cdr CHEMICAL)."""
96
+ doc = nlp_ner(text)
97
+ meds = [ent.text for ent in doc.ents if ent.label_ == "CHEMICAL"]
98
+ meds_unique = list(dict.fromkeys(meds))
99
+ return meds_unique
100
+
101
 
102
  @spaces.GPU
103
  def extract_text_from_image(image, temperature=0.2, stream=False):
104
  """Extract text from image using LightOnOCR model."""
 
105
  chat = [
106
  {
107
  "role": "user",
 
110
  ],
111
  }
112
  ]
 
 
113
  inputs = processor.apply_chat_template(
114
  chat,
115
  add_generation_prompt=True,
 
117
  return_dict=True,
118
  return_tensors="pt"
119
  )
 
 
120
  inputs = {
121
  k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
122
  else v.to(device) if isinstance(v, torch.Tensor)
123
  else v
124
  for k, v in inputs.items()
125
  }
 
126
  generation_kwargs = dict(
127
  **inputs,
128
  max_new_tokens=2048,
 
130
  use_cache=True,
131
  do_sample=temperature > 0,
132
  )
 
133
  if stream:
134
+ # Streaming generation
135
  streamer = TextIteratorStreamer(
136
  processor.tokenizer,
137
  skip_prompt=True,
138
  skip_special_tokens=True
139
  )
140
  generation_kwargs["streamer"] = streamer
 
 
141
  thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
142
  thread.start()
 
 
143
  full_text = ""
144
  for new_text in streamer:
145
  full_text += new_text
 
146
  cleaned_text = clean_output_text(full_text)
147
  yield cleaned_text
 
148
  thread.join()
149
  else:
150
  # Non-streaming generation
151
  with torch.no_grad():
152
  outputs = model.generate(**generation_kwargs)
 
 
153
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
 
 
154
  cleaned_text = clean_output_text(output_text)
 
155
  yield cleaned_text
156
 
 
157
  def process_input(file_input, temperature, page_num, enable_streaming):
158
+ """Process uploaded file (image or PDF) and extract medication names via OCR+NER."""
159
  if file_input is None:
160
  yield "Please upload an image or PDF first.", "", "", None, gr.update()
161
  return
 
162
  image_to_process = None
163
  page_info = ""
 
164
  file_path = file_input if isinstance(file_input, str) else file_input.name
 
165
  # Handle PDF files
166
  if file_path.lower().endswith('.pdf'):
167
  try:
 
178
  except Exception as e:
179
  yield f"Error opening image: {str(e)}", "", "", None, gr.update()
180
  return
 
181
  try:
 
182
  for extracted_text in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
183
+ meds = extract_medication_names(extracted_text)
184
+ meds_str = "\n".join(meds) if meds else "No medications found."
185
+ yield meds_str, meds_str, page_info, image_to_process, gr.update()
186
  except Exception as e:
187
  error_msg = f"Error during text extraction: {str(e)}"
188
  yield error_msg, error_msg, page_info, image_to_process, gr.update()
189
 
 
190
  def update_slider(file_input):
191
  """Update page slider based on PDF page count."""
192
  if file_input is None:
193
  return gr.update(maximum=20, value=1)
 
194
  file_path = file_input if isinstance(file_input, str) else file_input.name
 
195
  if file_path.lower().endswith('.pdf'):
196
  try:
197
  pdf = pdfium.PdfDocument(file_path)
 
203
  else:
204
  return gr.update(maximum=1, value=1)
205
 
206
+ # ----- GRADIO UI -----
207
+ with gr.Blocks(title="📖 Image/PDF OCR + Clinical NER", theme=gr.themes.Soft()) as demo:
 
208
  gr.Markdown(f"""
209
+ # 📖 Medication Extraction from Image/PDF with LightOnOCR + Clinical NER
210
 
211
  **💡 How to use:**
212
  1. Upload an image or PDF
213
+ 2. For PDFs: select which page to extract
214
  3. Adjust temperature if needed
215
+ 4. Click "Extract Medications"
216
 
217
+ **Output:** Only medication names found in text (via NER)
218
 
219
  **Model:** LightOnOCR-1B-1025 by LightOn AI
220
  **Device:** {device.upper()}
221
  **Attention:** {attn_implementation}
222
  """)
 
223
  with gr.Row():
224
  with gr.Column(scale=1):
225
  file_input = gr.File(
 
259
  value=True,
260
  info="Show text progressively as it's generated"
261
  )
262
+ submit_btn = gr.Button("Extract Medications", variant="primary")
263
  clear_btn = gr.Button("Clear", variant="secondary")
 
264
  with gr.Column(scale=2):
265
  output_text = gr.Markdown(
266
+ label="🩺 Extracted Medication Names",
267
+ value="*Medication names will appear here...*"
268
  )
 
269
  with gr.Row():
270
  with gr.Column():
271
  raw_output = gr.Textbox(
272
+ label="Extracted Medication Names (Raw)",
273
+ placeholder="Medication list will appear here...",
274
  lines=20,
275
  max_lines=30,
276
  show_copy_button=True
277
  )
 
278
  # Event handlers
279
  submit_btn.click(
280
  fn=process_input,
281
  inputs=[file_input, temperature, num_pages, enable_streaming],
282
  outputs=[output_text, raw_output, page_info, rendered_image, num_pages]
283
  )
 
284
  file_input.change(
285
  fn=update_slider,
286
  inputs=[file_input],
287
  outputs=[num_pages]
288
  )
 
289
  clear_btn.click(
290
+ fn=lambda: (None, "*Medication names will appear here...*", "", "", None, 1),
291
  outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
292
  )
293
 
 
294
  if __name__ == "__main__":
295
+ demo.launch()