ar08 committed on
Commit
17947a0
·
verified ·
1 Parent(s): 154213a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -21
app.py CHANGED
@@ -1,44 +1,53 @@
1
  # Required: pip install gradio transformers accelerate optimum onnxruntime onnx
2
 
3
  import gradio as gr
 
4
  from transformers import AutoTokenizer
5
  from optimum.onnxruntime import ORTModelForSeq2SeqLM
6
  from optimum.pipelines import pipeline
7
 
8
# Checkpoint shared by the model and the tokenizer.
model_name = "Rahmat82/t5-small-finetuned-summarization-xsum"

# export=True is forwarded to Optimum when loading the seq2seq model
# (presumably to convert the checkpoint to ONNX at load time -- confirm
# against the Optimum documentation).
model = ORTModelForSeq2SeqLM.from_pretrained(model_name, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Summarization pipeline used by the request handler below.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    batch_size=12,
)
 
 
 
 
 
 
 
15
 
16
# Summarization function with max input tokens and medium summary length
def summarize_text(text):
    """Summarize *text* with the module-level ``summarizer`` pipeline.

    Args:
        text: Raw user input from the UI. ``None``, an empty string, or a
            whitespace-only string is treated as "no input".

    Returns:
        The generated summary string, or the prompt
        ``"Please enter some text."`` when no usable input was given.
    """
    # Guard against None as well as blank strings: calling .strip() on
    # None would raise AttributeError instead of showing the prompt.
    if text is None or not text.strip():
        return "Please enter some text."

    # Tokenize with truncation at 1024 tokens, then decode back to text so
    # the pipeline receives an input that is already length-limited.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    input_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

    # Deterministic generation (no sampling) with medium length bounds.
    result = summarizer(
        input_text,
        min_length=100,
        max_length=120,
        do_sample=False,
    )
    return result[0]["summary_text"]
33
 
34
# Gradio UI: one multi-line input box wired to summarize_text, one output box.
_input_box = gr.Textbox(lines=15, placeholder="Paste your text here...", label="Input Text")
_output_box = gr.Textbox(label="Summary")

app = gr.Interface(
    fn=summarize_text,
    inputs=_input_box,
    outputs=_output_box,
    title="🚀 ONNX-Powered T5 Summarizer (Medium Summary)",
    description="Summarize long text into a medium-length summary using an ONNX-accelerated T5-small model (max input: 1024 tokens)",
)

# Start the web server as soon as the module is executed.
app.launch()
 
1
  # Required: pip install gradio transformers accelerate optimum onnxruntime onnx
2
 
3
  import gradio as gr
4
+ import torch
5
  from transformers import AutoTokenizer
6
  from optimum.onnxruntime import ORTModelForSeq2SeqLM
7
  from optimum.pipelines import pipeline
8
 
9
# Checkpoint shared by the model and the tokenizer.
model_name = "Rahmat82/t5-small-finetuned-summarization-xsum"
model = ORTModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Pick the pipeline device index: 0 when CUDA is available, otherwise -1.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Summarization pipeline used by the request handler below.
summarizer = pipeline(
    "summarization",
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=16,  # increased batch size for higher throughput
)
23
 
24
# Speed-optimized summarization function
def summarize_text(text):
    """Summarize *text* with the module-level ``summarizer`` pipeline.

    Args:
        text: Raw user input from the UI. ``None``, an empty string, or a
            whitespace-only string is treated as "no input".

    Returns:
        The generated summary string, or the prompt
        ``"Please enter some text."`` when no usable input was given.
    """
    # Normalise None to "" first: calling .strip() on None would raise
    # AttributeError instead of showing the prompt.
    text = (text or "").strip()
    if not text:
        return "Please enter some text."

    # Encode with truncation (max_length=1024) and decode straight back to
    # text so the pipeline receives an already length-limited input. Plain
    # id lists are enough here; no tensor is needed just to decode.
    token_ids = tokenizer.encode(text, max_length=1024, truncation=True)
    decoded_input = tokenizer.decode(token_ids, skip_special_tokens=True)

    # Deterministic generation (no sampling) with tighter length bounds.
    summary = summarizer(
        decoded_input,
        min_length=69,  # lower min length for faster generation
        max_length=120,
        do_sample=False,
    )
    return summary[0]["summary_text"]
42
 
43
# Gradio UI: one multi-line input box wired to summarize_text, one output box.
_input_box = gr.Textbox(lines=12, placeholder="Paste long text here...", label="Input Text")
_output_box = gr.Textbox(label="Summary")

app = gr.Interface(
    fn=summarize_text,
    inputs=_input_box,
    outputs=_output_box,
    title=" Fast ONNX T5 Summarizer",
    description="ONNX-accelerated T5-small model for quick, medium-length summarization (up to 1,024 tokens).",
)

# Only start the web server when this file is run as a script.
if __name__ == "__main__":
    app.launch()