# Required: pip install gradio transformers accelerate optimum onnxruntime onnx
import gradio as gr
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from optimum.pipelines import pipeline
# Load the model in ONNX format (export=True converts the PyTorch checkpoint to ONNX at load time) and its tokenizer
model_name = "Rahmat82/t5-small-finetuned-summarization-xsum"
model = ORTModelForSeq2SeqLM.from_pretrained(model_name, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Create summarizer pipeline with Optimum
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device_map="auto", batch_size=12)
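# Note (assumption, not part of the original app): summary length uses the model's
# generation defaults here; it could be tuned per call, e.g.
# summarizer(text, max_length=60, min_length=10)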
# Define summarization function with 1024 token cap
def summarize_text(text):
    if not text.strip():
        return "Please enter some text."
    # Tokenize and truncate to max 1024 tokens
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
    input_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
    # Summarize the truncated text
    result = summarizer(input_text)
    return result[0]["summary_text"]
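# Quick local sanity check (hypothetical example input, not part of the Gradio flow):
# print(summarize_text("The quick brown fox jumps over the lazy dog. " * 50))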
# Gradio app
app = gr.Interface(
    fn=summarize_text,
    inputs=gr.Textbox(lines=15, placeholder="Paste your text here...", label="Input Text"),
    outputs=gr.Textbox(label="Summary"),
    title="🚀 ONNX-Powered T5 Summarizer (1024 tokens)",
    description="Summarize long text using a fine-tuned ONNX-accelerated T5-small model (max input: 1024 tokens)"
)
# Launch the app
app.launch()
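# Optional launch variants (assumptions about deployment, not used above):
# app.launch(share=True) creates a temporary public Gradio link, and
# app.launch(server_name="0.0.0.0", server_port=7860) binds to all interfaces.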