IFMedTechdemo commited on
Commit
95d2834
·
verified ·
1 Parent(s): 60d9cea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -105
app.py CHANGED
@@ -8,27 +8,23 @@ import torch
8
 
9
  import gradio as gr
10
  from PIL import Image
11
- from io import BytesIO
 
12
  import pypdfium2 as pdfium
13
  from transformers import (
14
  LightOnOCRForConditionalGeneration,
15
  LightOnOCRProcessor,
16
  )
17
-
18
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
-
22
  if device == "cuda":
23
  attn_implementation = "sdpa"
24
  dtype = torch.bfloat16
25
- print("Using sdpa for GPU")
26
  else:
27
  attn_implementation = "eager"
28
  dtype = torch.float32
29
- print("Using eager attention for CPU")
30
 
31
- print(f"Loading LightOnOCR model on {device} with {attn_implementation} attention...")
32
  ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
33
  "lightonai/LightOnOCR-1B-1025",
34
  attn_implementation=attn_implementation,
@@ -40,10 +36,7 @@ processor = LightOnOCRProcessor.from_pretrained(
40
  "lightonai/LightOnOCR-1B-1025",
41
  trust_remote_code=True,
42
  )
43
- print("LightOnOCR model loaded successfully!")
44
 
45
- # -------- Clinical NER models (load ONCE) --------
46
- print("Loading clinical NER model...")
47
  ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
48
  ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
49
  ner_pipeline = pipeline(
@@ -52,11 +45,8 @@ ner_pipeline = pipeline(
52
  tokenizer=ner_tokenizer,
53
  aggregation_strategy="simple",
54
  )
55
- print("Clinical NER model loaded successfully!")
56
-
57
 
58
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
59
- """Render a PDF page to PIL Image."""
60
  width, height = page.get_size()
61
  pixel_width = width * scale
62
  pixel_height = height * scale
@@ -64,61 +54,58 @@ def render_pdf_page(page, max_resolution=1540, scale=2.77):
64
  target_scale = scale * resize_factor
65
  return page.render(scale=target_scale, rev_byteorder=True).to_pil()
66
 
67
-
68
  def process_pdf(pdf_path, page_num=1):
69
- """Extract a specific page from PDF."""
70
  pdf = pdfium.PdfDocument(pdf_path)
71
  total_pages = len(pdf)
72
  page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
73
-
74
  page = pdf[page_idx]
75
  img = render_pdf_page(page)
76
-
77
  pdf.close()
78
  return img, total_pages, page_idx + 1
79
 
80
-
81
  def clean_output_text(text):
82
- """Remove chat template artifacts from output."""
83
- # Remove common chat template markers
84
  markers_to_remove = ["system", "user", "assistant"]
85
-
86
- # Split by lines and filter
87
  lines = text.split('\n')
88
  cleaned_lines = []
89
-
90
  for line in lines:
91
  stripped = line.strip()
92
- # Skip lines that are just template markers
93
  if stripped.lower() not in markers_to_remove:
94
  cleaned_lines.append(line)
95
-
96
- # Join back and strip leading/trailing whitespace
97
  cleaned = '\n'.join(cleaned_lines).strip()
98
-
99
- # Alternative approach: if there's an "assistant" marker, take everything after it
100
  if "assistant" in text.lower():
101
  parts = text.split("assistant", 1)
102
  if len(parts) > 1:
103
  cleaned = parts[1].strip()
104
-
105
  return cleaned
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  @spaces.GPU
109
  def extract_text_from_image(image, temperature=0.2):
110
- """Extract text from image using LightOnOCR model, and run clinical NER."""
111
- # Prepare the chat format
112
  chat = [
113
  {
114
  "role": "user",
115
  "content": [
116
- {"type": "image", "url": image}, # adjust to {"type": "image", "image": image} if LightOnOCR expects that
117
  ],
118
  }
119
  ]
120
-
121
- # Tokenize
122
  inputs = processor.apply_chat_template(
123
  chat,
124
  add_generation_prompt=True,
@@ -126,7 +113,6 @@ def extract_text_from_image(image, temperature=0.2):
126
  return_dict=True,
127
  return_tensors="pt",
128
  )
129
-
130
  # Move inputs to device
131
  inputs = {
132
  k: (
@@ -138,7 +124,6 @@ def extract_text_from_image(image, temperature=0.2):
138
  )
139
  for k, v in inputs.items()
140
  }
141
-
142
  generation_kwargs = dict(
143
  **inputs,
144
  max_new_tokens=2048,
@@ -146,19 +131,12 @@ def extract_text_from_image(image, temperature=0.2):
146
  use_cache=True,
147
  do_sample=temperature > 0,
148
  )
149
-
150
- # Non-streaming generation
151
  with torch.no_grad():
152
  outputs = ocr_model.generate(**generation_kwargs)
153
 
154
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
155
- cleaned_text = clean_output_text(output_text)
156
-
157
- print("\n this is cleaned_text",cleaned_text )
158
-
159
- # Clinical NER on the full cleaned text
160
- entities = ner_pipeline(cleaned_text)
161
- print("\n this is entity",entities)
162
  medications = []
163
  for ent in entities:
164
  if ent["entity_group"] == "treatment":
@@ -167,28 +145,19 @@ def extract_text_from_image(image, temperature=0.2):
167
  medications[-1] += word[2:]
168
  else:
169
  medications.append(word)
170
-
171
  medications_str = ", ".join(set(medications)) if medications else "None detected"
172
-
173
- yield cleaned_text, medications_str
174
-
175
-
176
-
177
 
178
  def process_input(file_input, temperature, page_num):
179
- """Process uploaded file (image or PDF) and extract text with optional streaming."""
180
  if file_input is None:
181
- # 6 outputs: [output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
182
- yield "Please upload an image or PDF first.", "", "", "", None, 1
183
  return
184
 
185
  image_to_process = None
186
  page_info = ""
187
  slider_value = page_num
188
-
189
  file_path = file_input if isinstance(file_input, str) else file_input.name
190
 
191
- # Handle PDF files
192
  if file_path.lower().endswith(".pdf"):
193
  try:
194
  image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
@@ -199,7 +168,6 @@ def process_input(file_input, temperature, page_num):
199
  yield msg, "", msg, "", None, slider_value
200
  return
201
  else:
202
- # Handle image files
203
  try:
204
  image_to_process = Image.open(file_path)
205
  page_info = "Processing image"
@@ -209,29 +177,18 @@ def process_input(file_input, temperature, page_num):
209
  return
210
 
211
  try:
212
- # Extract text using LightOnOCR with optional streaming
213
- for extracted_text, medications in extract_text_from_image(
214
  image_to_process, temperature
215
  ):
216
- raw_md = extracted_text # or you can keep a different raw version
217
- # 6 outputs: markdown_text, medications, raw_output, page_info, image, slider
218
- yield extracted_text, medications, raw_md, page_info, image_to_process, gr.update(
219
- value=slider_value
220
- )
221
-
222
  except Exception as e:
223
  error_msg = f"Error during text extraction: {str(e)}"
224
- # 6 outputs
225
- yield error_msg, "", error_msg, page_info, image_to_process, gr.update(value=slider_value)
226
-
227
 
228
  def update_slider(file_input):
229
- """Update page slider based on PDF page count."""
230
  if file_input is None:
231
  return gr.update(maximum=20, value=1)
232
-
233
  file_path = file_input if isinstance(file_input, str) else file_input.name
234
-
235
  if file_path.lower().endswith('.pdf'):
236
  try:
237
  pdf = pdfium.PdfDocument(file_path)
@@ -243,6 +200,75 @@ def update_slider(file_input):
243
  else:
244
  return gr.update(maximum=1, value=1)
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  # Create Gradio interface
248
  # with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
@@ -330,41 +356,6 @@ def update_slider(file_input):
330
  # outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
331
  # )
332
 
333
- with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
334
- file_input = gr.File(
335
- label="🖼️ Upload Image or PDF",
336
- file_types=[".pdf", ".png", ".jpg", ".jpeg"],
337
- type="filepath"
338
- )
339
- temperature = gr.Slider(
340
- minimum=0.0,
341
- maximum=1.0,
342
- value=0.2,
343
- step=0.05,
344
- label="Temperature",
345
- info="0.0 = deterministic, Higher = more varied"
346
- )
347
- medicines_output = gr.Textbox(
348
- label="💊 Extracted Medicines/Drugs",
349
- placeholder="Medicine/drug names will appear here...",
350
- lines=2,
351
- max_lines=5,
352
- interactive=False,
353
- show_copy_button=True
354
- )
355
- submit_btn = gr.Button("Extract Medicines", variant="primary")
356
-
357
- submit_btn.click(
358
- fn=process_input, # already yields medicines as second output
359
- inputs=[file_input, temperature, 1], # fix page=1 or expose slider
360
- outputs=[gr.update(), medicines_output, gr.update(), gr.update(), gr.update(), gr.update()]
361
- )
362
-
363
-
364
-
365
-
366
- if __name__ == "__main__":
367
- demo.launch()
368
 
369
 
370
 
 
8
 
9
  import gradio as gr
10
  from PIL import Image
11
+ import numpy as np
12
+ import cv2
13
  import pypdfium2 as pdfium
14
  from transformers import (
15
  LightOnOCRForConditionalGeneration,
16
  LightOnOCRProcessor,
17
  )
 
18
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
19
 
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  if device == "cuda":
22
  attn_implementation = "sdpa"
23
  dtype = torch.bfloat16
 
24
  else:
25
  attn_implementation = "eager"
26
  dtype = torch.float32
 
27
 
 
28
  ocr_model = LightOnOCRForConditionalGeneration.from_pretrained(
29
  "lightonai/LightOnOCR-1B-1025",
30
  attn_implementation=attn_implementation,
 
36
  "lightonai/LightOnOCR-1B-1025",
37
  trust_remote_code=True,
38
  )
 
39
 
 
 
40
  ner_tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
41
  ner_model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
42
  ner_pipeline = pipeline(
 
45
  tokenizer=ner_tokenizer,
46
  aggregation_strategy="simple",
47
  )
 
 
48
 
49
  def render_pdf_page(page, max_resolution=1540, scale=2.77):
 
50
  width, height = page.get_size()
51
  pixel_width = width * scale
52
  pixel_height = height * scale
 
54
  target_scale = scale * resize_factor
55
  return page.render(scale=target_scale, rev_byteorder=True).to_pil()
56
 
 
57
  def process_pdf(pdf_path, page_num=1):
 
58
  pdf = pdfium.PdfDocument(pdf_path)
59
  total_pages = len(pdf)
60
  page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
 
61
  page = pdf[page_idx]
62
  img = render_pdf_page(page)
 
63
  pdf.close()
64
  return img, total_pages, page_idx + 1
65
 
 
66
  def clean_output_text(text):
 
 
67
  markers_to_remove = ["system", "user", "assistant"]
 
 
68
  lines = text.split('\n')
69
  cleaned_lines = []
 
70
  for line in lines:
71
  stripped = line.strip()
 
72
  if stripped.lower() not in markers_to_remove:
73
  cleaned_lines.append(line)
 
 
74
  cleaned = '\n'.join(cleaned_lines).strip()
 
 
75
  if "assistant" in text.lower():
76
  parts = text.split("assistant", 1)
77
  if len(parts) > 1:
78
  cleaned = parts[1].strip()
 
79
  return cleaned
80
 
81
+ def preprocess_image_for_ocr(image):
82
+ """Convert PIL.Image to adaptive thresholded image for OCR."""
83
+ image_rgb = image.convert("RGB")
84
+ img_np = np.array(image_rgb)
85
+ gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
86
+ adaptive_threshold = cv2.adaptiveThreshold(
87
+ gray,
88
+ 255,
89
+ cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
90
+ cv2.THRESH_BINARY,
91
+ 85,
92
+ 11,
93
+ )
94
+ preprocessed_pil = Image.fromarray(adaptive_threshold)
95
+ return preprocessed_pil
96
 
97
  @spaces.GPU
98
  def extract_text_from_image(image, temperature=0.2):
99
+ """OCR + clinical NER, with preprocessing."""
100
+ processed_img = preprocess_image_for_ocr(image)
101
  chat = [
102
  {
103
  "role": "user",
104
  "content": [
105
+ {"type": "image", "image": processed_img}
106
  ],
107
  }
108
  ]
 
 
109
  inputs = processor.apply_chat_template(
110
  chat,
111
  add_generation_prompt=True,
 
113
  return_dict=True,
114
  return_tensors="pt",
115
  )
 
116
  # Move inputs to device
117
  inputs = {
118
  k: (
 
124
  )
125
  for k, v in inputs.items()
126
  }
 
127
  generation_kwargs = dict(
128
  **inputs,
129
  max_new_tokens=2048,
 
131
  use_cache=True,
132
  do_sample=temperature > 0,
133
  )
 
 
134
  with torch.no_grad():
135
  outputs = ocr_model.generate(**generation_kwargs)
136
 
137
  output_text = processor.decode(outputs[0], skip_special_tokens=True)
138
+ cleaned_text = clean_output_text(output_text)
139
+ entities = ner_pipeline(cleaned_text)
 
 
 
 
 
140
  medications = []
141
  for ent in entities:
142
  if ent["entity_group"] == "treatment":
 
145
  medications[-1] += word[2:]
146
  else:
147
  medications.append(word)
 
148
  medications_str = ", ".join(set(medications)) if medications else "None detected"
149
+ yield cleaned_text, medications_str, output_text, processed_img
 
 
 
 
150
 
151
  def process_input(file_input, temperature, page_num):
 
152
  if file_input is None:
153
+ yield "Please upload an image or PDF first.", "", "", "", "No file!", 1
 
154
  return
155
 
156
  image_to_process = None
157
  page_info = ""
158
  slider_value = page_num
 
159
  file_path = file_input if isinstance(file_input, str) else file_input.name
160
 
 
161
  if file_path.lower().endswith(".pdf"):
162
  try:
163
  image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
 
168
  yield msg, "", msg, "", None, slider_value
169
  return
170
  else:
 
171
  try:
172
  image_to_process = Image.open(file_path)
173
  page_info = "Processing image"
 
177
  return
178
 
179
  try:
180
+ for cleaned_text, medications, raw_md, processed_img in extract_text_from_image(
 
181
  image_to_process, temperature
182
  ):
183
+ yield cleaned_text, medications, raw_md, page_info, processed_img, slider_value
 
 
 
 
 
184
  except Exception as e:
185
  error_msg = f"Error during text extraction: {str(e)}"
186
+ yield error_msg, "", error_msg, page_info, image_to_process, slider_value
 
 
187
 
188
  def update_slider(file_input):
 
189
  if file_input is None:
190
  return gr.update(maximum=20, value=1)
 
191
  file_path = file_input if isinstance(file_input, str) else file_input.name
 
192
  if file_path.lower().endswith('.pdf'):
193
  try:
194
  pdf = pdfium.PdfDocument(file_path)
 
200
  else:
201
  return gr.update(maximum=1, value=1)
202
 
203
+ with gr.Blocks(title="💊 Medicine Extraction", theme=gr.themes.Soft()) as demo:
204
+ file_input = gr.File(
205
+ label="🖼️ Upload Image or PDF",
206
+ file_types=[".pdf", ".png", ".jpg", ".jpeg"],
207
+ type="filepath"
208
+ )
209
+ temperature = gr.Slider(
210
+ minimum=0.0,
211
+ maximum=1.0,
212
+ value=0.2,
213
+ step=0.05,
214
+ label="Temperature"
215
+ )
216
+ page_slider = gr.Slider(
217
+ minimum=1, maximum=20, value=1, step=1,
218
+ label="Page Number (PDF only)",
219
+ interactive=True
220
+ )
221
+ output_text = gr.Textbox(
222
+ label="📝 Extracted Text",
223
+ lines=4,
224
+ max_lines=10,
225
+ interactive=False,
226
+ show_copy_button=True
227
+ )
228
+ medicines_output = gr.Textbox(
229
+ label="💊 Extracted Medicines/Drugs",
230
+ placeholder="Medicine/drug names will appear here...",
231
+ lines=2,
232
+ max_lines=5,
233
+ interactive=False,
234
+ show_copy_button=True
235
+ )
236
+ raw_output = gr.Textbox(
237
+ label="Raw Model Output",
238
+ lines=2,
239
+ max_lines=5,
240
+ interactive=False
241
+ )
242
+ page_info = gr.Markdown(
243
+ value="", # Info of PDF page
244
+ interactive=False
245
+ )
246
+ rendered_image = gr.Image(
247
+ label="Processed Image (Thresholded for OCR)",
248
+ interactive=False
249
+ )
250
+ num_pages = gr.Number(
251
+ value=1, label="Current Page (slider)", visible=False
252
+ )
253
+ submit_btn = gr.Button("Extract Medicines", variant="primary")
254
+
255
+ submit_btn.click(
256
+ fn=process_input,
257
+ inputs=[file_input, temperature, page_slider],
258
+ outputs=[output_text, medicines_output, raw_output, page_info, rendered_image, num_pages]
259
+ )
260
+
261
+ file_input.change(
262
+ fn=update_slider,
263
+ inputs=[file_input],
264
+ outputs=[page_slider]
265
+ )
266
+
267
+ if __name__ == "__main__":
268
+ demo.launch()
269
+
270
+
271
+
272
 
273
  # Create Gradio interface
274
  # with gr.Blocks(title="📖 Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
 
356
  # outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
357
  # )
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
 
361