ar08 committed on
Commit
2716805
·
verified ·
1 Parent(s): 17947a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -22
app.py CHANGED
@@ -1,52 +1,70 @@
1
- # Required: pip install gradio transformers accelerate optimum onnxruntime onnx
2
 
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer
6
- from optimum.onnxruntime import ORTModelForSeq2SeqLM
7
- from optimum.pipelines import pipeline
 
8
 
9
- # Load ONNX-optimized model and tokenizer
10
  model_name = "Rahmat82/t5-small-finetuned-summarization-xsum"
11
- model = ORTModelForSeq2SeqLM.from_pretrained(model_name)
12
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # Build a fast summarization pipeline
15
- device = 0 if torch.cuda.is_available() else -1
 
 
 
 
 
 
 
 
16
  summarizer = pipeline(
17
  task="summarization",
18
  model=model,
19
  tokenizer=tokenizer,
20
- device=device,
21
- batch_size=16, # increased batch size for higher throughput
 
 
22
  )
23
 
24
- # Speed-optimized summarization function
25
  def summarize_text(text):
26
  text = text.strip()
27
  if not text:
28
  return "Please enter some text."
29
-
30
- # Encode with truncation (max_length=1024)
31
  inputs = tokenizer.encode(text, max_length=1024, truncation=True, return_tensors="pt")
32
- decoded_input = tokenizer.decode(inputs[0], skip_special_tokens=True)
33
-
34
- # Generate summary with tighter bounds
35
  summary = summarizer(
36
- decoded_input,
37
- min_length=69, # lower min length for faster generation
38
  max_length=120,
39
  do_sample=False
40
  )
41
  return summary[0]["summary_text"]
42
 
43
- # Gradio interface
44
  app = gr.Interface(
45
  fn=summarize_text,
46
- inputs=gr.Textbox(lines=12, placeholder="Paste long text here...", label="Input Text"),
47
  outputs=gr.Textbox(label="Summary"),
48
- title=" Fast ONNX T5 Summarizer",
49
- description="ONNX-accelerated T5-small model for quick, medium-length summarization (up to 1,024 tokens)."
50
  )
51
 
52
  if __name__ == "__main__":
 
1
+ # pip install gradio transformers onnxruntime optimum torch
2
 
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoTokenizer
6
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM, ORTOptimizer, ORTQuantizer
7
+ from optimum.onnxruntime.configuration import AutoOptimizationConfig
8
+ import onnxruntime as ort
9
 
10
+ # Step 1: Load & optimize the ONNX model
11
  model_name = "Rahmat82/t5-small-finetuned-summarization-xsum"
12
+ model = ORTModelForSeq2SeqLM.from_pretrained(model_name, export=True)
13
+
14
+ optimizer = ORTOptimizer.from_pretrained(model)
15
+ opt_config = AutoOptimizationConfig.O2() # graph fusions and transformer-specific optimizations
16
+ optimizer.optimize(save_dir="optimized_model", optimization_config=opt_config)
17
+ optimized_model = ORTModelForSeq2SeqLM.from_pretrained("optimized_model")
18
+
19
+ # Step 2: Apply dynamic INT8 quantization for CPU
20
+ quantizer = ORTQuantizer.from_pretrained(optimized_model)
21
+ opt_q = quantizer.quantize(
22
+ save_dir="quantized_model",
23
+ quantization_config=AutoOptimizationConfig.O2().quantization_config, # dynamic quant
24
+ )
25
+ model = ORTModelForSeq2SeqLM.from_pretrained("quantized_model")
26
 
27
+ # Step 3: Set up ONNXRuntime Session options for CPU multi-threading
28
+ sess_options = ort.SessionOptions()
29
+ sess_options.intra_op_num_threads = min(4, torch.get_num_threads()) # cap intra-op parallelism at 4 threads (fewer if torch reports fewer)
30
+ sess_options.inter_op_num_threads = 1
31
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
32
+
33
+ # Rebuild pipeline with optimized quantized model on CPU
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
35
+ summarizer = gradio_pipeline = None
36
+ from optimum.pipelines import pipeline
37
  summarizer = pipeline(
38
  task="summarization",
39
  model=model,
40
  tokenizer=tokenizer,
41
+ framework="pt",
42
+ ort_session_options=sess_options,
43
+ device=-1,
44
+ batch_size=8,
45
  )
46
 
 
47
  def summarize_text(text):
48
  text = text.strip()
49
  if not text:
50
  return "Please enter some text."
 
 
51
  inputs = tokenizer.encode(text, max_length=1024, truncation=True, return_tensors="pt")
52
+ decoded = tokenizer.decode(inputs[0], skip_special_tokens=True)
 
 
53
  summary = summarizer(
54
+ decoded,
55
+ min_length=60,
56
  max_length=120,
57
  do_sample=False
58
  )
59
  return summary[0]["summary_text"]
60
 
61
+ # Gradio UI
62
  app = gr.Interface(
63
  fn=summarize_text,
64
+ inputs=gr.Textbox(lines=12, label="Input Text"),
65
  outputs=gr.Textbox(label="Summary"),
66
+ title="⚙️ CPU-Optimized ONNX T5 Summarizer",
67
+ description="Uses graph optimizations, INT8 quantization, and threading tweaks for fast CPU performance."
68
  )
69
 
70
  if __name__ == "__main__":