IFMedTechdemo commited on
Commit
f927c87
ยท
verified ยท
1 Parent(s): 8cf36b4

Create LightOCR_ClinicalNER_Test

Browse files
Files changed (1) hide show
  1. LightOCR_ClinicalNER_Test +356 -0
LightOCR_ClinicalNER_Test ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import subprocess
3
+ import sys
4
+ import threading
5
+
6
+ import spaces
7
+ import torch
8
+
9
+ import gradio as gr
10
+ from PIL import Image
11
+ from io import BytesIO
12
+ import pypdfium2 as pdfium
13
+ from transformers import (
14
+ LightOnOCRForConditionalGeneration,
15
+ LightOnOCRProcessor,
16
+ TextIteratorStreamer,
17
+ )
18
+
19
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
20
+
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
+ # Choose best attention implementation based on device
24
+ if device == "cuda":
25
+ attn_implementation = "sdpa"
26
+ dtype = torch.bfloat16
27
+ print("Using sdpa for GPU")
28
+ else:
29
+ attn_implementation = "eager" # Best for CPU
30
+ dtype = torch.float32
31
+ print("Using eager attention for CPU")
32
+
33
+ # Initialize the LightOnOCR model and processor
34
+ print(f"Loading model on {device} with {attn_implementation} attention...")
35
+ model = LightOnOCRForConditionalGeneration.from_pretrained(
36
+ "lightonai/LightOnOCR-1B-1025",
37
+ attn_implementation=attn_implementation,
38
+ torch_dtype=dtype,
39
+ trust_remote_code=True
40
+ ).to(device).eval()
41
+
42
+ processor = LightOnOCRProcessor.from_pretrained(
43
+ "lightonai/LightOnOCR-1B-1025",
44
+ trust_remote_code=True
45
+ )
46
+ print("Model loaded successfully!")
47
+
48
+
49
+ def render_pdf_page(page, max_resolution=1540, scale=2.77):
50
+ """Render a PDF page to PIL Image."""
51
+ width, height = page.get_size()
52
+ pixel_width = width * scale
53
+ pixel_height = height * scale
54
+ resize_factor = min(1, max_resolution / pixel_width, max_resolution / pixel_height)
55
+ target_scale = scale * resize_factor
56
+ return page.render(scale=target_scale, rev_byteorder=True).to_pil()
57
+
58
+
59
+ def process_pdf(pdf_path, page_num=1):
60
+ """Extract a specific page from PDF."""
61
+ pdf = pdfium.PdfDocument(pdf_path)
62
+ total_pages = len(pdf)
63
+ page_idx = min(max(int(page_num) - 1, 0), total_pages - 1)
64
+
65
+ page = pdf[page_idx]
66
+ img = render_pdf_page(page)
67
+
68
+ pdf.close()
69
+ return img, total_pages, page_idx + 1
70
+
71
+
72
+ def clean_output_text(text):
73
+ """Remove chat template artifacts from output."""
74
+ # Remove common chat template markers
75
+ markers_to_remove = ["system", "user", "assistant"]
76
+
77
+ # Split by lines and filter
78
+ lines = text.split('\n')
79
+ cleaned_lines = []
80
+
81
+ for line in lines:
82
+ stripped = line.strip()
83
+ # Skip lines that are just template markers
84
+ if stripped.lower() not in markers_to_remove:
85
+ cleaned_lines.append(line)
86
+
87
+ # Join back and strip leading/trailing whitespace
88
+ cleaned = '\n'.join(cleaned_lines).strip()
89
+
90
+ # Alternative approach: if there's an "assistant" marker, take everything after it
91
+ if "assistant" in text.lower():
92
+ parts = text.split("assistant", 1)
93
+ if len(parts) > 1:
94
+ cleaned = parts[1].strip()
95
+
96
+ return cleaned
97
+
98
+
99
+ @spaces.GPU
100
+ def extract_text_from_image(image, temperature=0.2, stream=False):
101
+ """Extract text from image using LightOnOCR model."""
102
+ # Prepare the chat format
103
+ chat = [
104
+ {
105
+ "role": "user",
106
+ "content": [
107
+ {"type": "image", "url": image},
108
+ ],
109
+ }
110
+ ]
111
+
112
+ # Apply chat template and tokenize
113
+ inputs = processor.apply_chat_template(
114
+ chat,
115
+ add_generation_prompt=True,
116
+ tokenize=True,
117
+ return_dict=True,
118
+ return_tensors="pt"
119
+ )
120
+
121
+ # Move inputs to device AND convert to the correct dtype
122
+ inputs = {
123
+ k: v.to(device=device, dtype=dtype) if isinstance(v, torch.Tensor) and v.dtype in [torch.float32, torch.float16, torch.bfloat16]
124
+ else v.to(device) if isinstance(v, torch.Tensor)
125
+ else v
126
+ for k, v in inputs.items()
127
+ }
128
+
129
+ generation_kwargs = dict(
130
+ **inputs,
131
+ max_new_tokens=2048,
132
+ temperature=temperature if temperature > 0 else 0.0,
133
+ use_cache=True,
134
+ do_sample=temperature > 0,
135
+ )
136
+
137
+ if stream:
138
+ # Setup streamer for streaming generation
139
+ streamer = TextIteratorStreamer(
140
+ processor.tokenizer,
141
+ skip_prompt=True,
142
+ skip_special_tokens=True
143
+ )
144
+ generation_kwargs["streamer"] = streamer
145
+
146
+ # Run generation in a separate thread
147
+ thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
148
+ thread.start()
149
+
150
+ # Yield chunks as they arrive
151
+ full_text = ""
152
+ for new_text in streamer:
153
+ full_text += new_text
154
+ # Clean the accumulated text
155
+ cleaned_text = clean_output_text(full_text)
156
+ yield cleaned_text
157
+
158
+ thread.join()
159
+ else:
160
+ # Non-streaming generation
161
+ with torch.no_grad():
162
+ outputs = model.generate(**generation_kwargs)
163
+
164
+ # Decode the output
165
+ output_text = processor.decode(outputs[0], skip_special_tokens=True)
166
+
167
+ # Clean the output
168
+ cleaned_text = clean_output_text(output_text)
169
+
170
+ ######### clinical NER ##############
171
+
172
+ tokenizer = AutoTokenizer.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
173
+ model = AutoModelForTokenClassification.from_pretrained("samrawal/bert-base-uncased_clinical-ner")
174
+ ner = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
175
+
176
+
177
+ Clinical NER process
178
+ entities = ner(cleaned_text)
179
+ medications = []
180
+ for ent in entities:
181
+ if ent["entity_group"] == "treatment":
182
+ word = ent["word"]
183
+ if word.startswith("##") and medications:
184
+ medications[-1] += word[2:]
185
+ else:
186
+ medications.append(word)
187
+ medications_str = ", ".join(set(medications)) if medications else "None detected"
188
+
189
+ yield cleaned_text
190
+ yield medications_s
191
+
192
+
193
+
194
+
195
+ def process_input(file_input, temperature, page_num, enable_streaming):
196
+ """Process uploaded file (image or PDF) and extract text with optional streaming."""
197
+ if file_input is None:
198
+ yield "Please upload an image or PDF first.", "", "", None, gr.update()
199
+ return
200
+
201
+ image_to_process = None
202
+ page_info = ""
203
+
204
+ file_path = file_input if isinstance(file_input, str) else file_input.name
205
+
206
+ # Handle PDF files
207
+ if file_path.lower().endswith('.pdf'):
208
+ try:
209
+ image_to_process, total_pages, actual_page = process_pdf(file_path, int(page_num))
210
+ page_info = f"Processing page {actual_page} of {total_pages}"
211
+ except Exception as e:
212
+ yield f"Error processing PDF: {str(e)}", "", "", None, gr.update()
213
+ return
214
+ # Handle image files
215
+ else:
216
+ try:
217
+ image_to_process = Image.open(file_path)
218
+ page_info = "Processing image"
219
+ except Exception as e:
220
+ yield f"Error opening image: {str(e)}", "", "", None, gr.update()
221
+ return
222
+
223
+ try:
224
+ # Extract text using LightOnOCR with optional streaming
225
+ for extracted_text, medications in extract_text_from_image(image_to_process, temperature, stream=enable_streaming):
226
+ yield extracted_text, medications, page_info, image_to_process, gr.update()
227
+
228
+ except Exception as e:
229
+ error_msg = f"Error during text extraction: {str(e)}"
230
+ yield error_msg, error_msg, page_info, image_to_process, gr.update()
231
+
232
+
233
+ def update_slider(file_input):
234
+ """Update page slider based on PDF page count."""
235
+ if file_input is None:
236
+ return gr.update(maximum=20, value=1)
237
+
238
+ file_path = file_input if isinstance(file_input, str) else file_input.name
239
+
240
+ if file_path.lower().endswith('.pdf'):
241
+ try:
242
+ pdf = pdfium.PdfDocument(file_path)
243
+ total_pages = len(pdf)
244
+ pdf.close()
245
+ return gr.update(maximum=total_pages, value=1)
246
+ except:
247
+ return gr.update(maximum=20, value=1)
248
+ else:
249
+ return gr.update(maximum=1, value=1)
250
+
251
+
252
+ # Create Gradio interface
253
+ with gr.Blocks(title="๐Ÿ“– Image/PDF OCR with LightOnOCR", theme=gr.themes.Soft()) as demo:
254
+ gr.Markdown(f"""
255
+ # ๐Ÿ“– Image/PDF to Text Extraction with LightOnOCR
256
+
257
+ **๐Ÿ’ก How to use:**
258
+ 1. Upload an image or PDF
259
+ 2. For PDFs: select which page to extract (1-20)
260
+ 3. Adjust temperature if needed
261
+ 4. Click "Extract Text"
262
+
263
+ **Note:** The Markdown rendering for tables may not always be perfect. Check the raw output for complex tables!
264
+
265
+ **Model:** LightOnOCR-1B-1025 by LightOn AI
266
+ **Device:** {device.upper()}
267
+ **Attention:** {attn_implementation}
268
+ """)
269
+
270
+ with gr.Row():
271
+ with gr.Column(scale=1):
272
+ file_input = gr.File(
273
+ label="๐Ÿ–ผ๏ธ Upload Image or PDF",
274
+ file_types=[".pdf", ".png", ".jpg", ".jpeg"],
275
+ type="filepath"
276
+ )
277
+ rendered_image = gr.Image(
278
+ label="๐Ÿ“„ Preview",
279
+ type="pil",
280
+ height=400,
281
+ interactive=False
282
+ )
283
+ num_pages = gr.Slider(
284
+ minimum=1,
285
+ maximum=20,
286
+ value=1,
287
+ step=1,
288
+ label="PDF: Page Number",
289
+ info="Select which page to extract"
290
+ )
291
+ page_info = gr.Textbox(
292
+ label="Processing Info",
293
+ value="",
294
+ interactive=False
295
+ )
296
+ temperature = gr.Slider(
297
+ minimum=0.0,
298
+ maximum=1.0,
299
+ value=0.2,
300
+ step=0.05,
301
+ label="Temperature",
302
+ info="0.0 = deterministic, Higher = more varied"
303
+ )
304
+ enable_streaming = gr.Checkbox(
305
+ label="Enable Streaming",
306
+ value=True,
307
+ info="Show text progressively as it's generated"
308
+ )
309
+ submit_btn = gr.Button("Extract Text", variant="primary")
310
+ clear_btn = gr.Button("Clear", variant="secondary")
311
+
312
+ with gr.Column(scale=2):
313
+ output_text = gr.Markdown(
314
+ label="๐Ÿ“„ Extracted Text (Rendered)",
315
+ value="*Extracted text will appear here...*"
316
+ )
317
+ medications_output = gr.Textbox(
318
+ label="๐Ÿ’Š Extracted Medicines/Drugs",
319
+ placeholder="Medicine/drug names will appear here...",
320
+ lines=2,
321
+ max_lines=5,
322
+ interactive=False,
323
+ show_copy_button=True
324
+ )
325
+
326
+ with gr.Row():
327
+ with gr.Column():
328
+ raw_output = gr.Textbox(
329
+ label="Raw Markdown Output",
330
+ placeholder="Raw text will appear here...",
331
+ lines=20,
332
+ max_lines=30,
333
+ show_copy_button=True
334
+ )
335
+
336
+ # Event handlers
337
+ submit_btn.click(
338
+ fn=process_input,
339
+ inputs=[file_input, temperature, num_pages, enable_streaming],
340
+ outputs=[output_text, medications_output, raw_output, page_info, rendered_image, num_pages]
341
+ )
342
+
343
+ file_input.change(
344
+ fn=update_slider,
345
+ inputs=[file_input],
346
+ outputs=[num_pages]
347
+ )
348
+
349
+ clear_btn.click(
350
+ fn=lambda: (None, "*Extracted text will appear here...*", "", "", None, 1),
351
+ outputs=[file_input, output_text, raw_output, page_info, rendered_image, num_pages]
352
+ )
353
+
354
+
355
+ if __name__ == "__main__":
356
+ demo.launch()