LutherYTT commited on
Commit
c548aa4
·
1 Parent(s): a8eefc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -0
app.py CHANGED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from safetensors.torch import load_file
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import gc
7
+
8
+ # 清理内存
9
+ gc.collect()
10
+ torch.cuda.empty_cache()
11
+
12
+ # 1. 定义MultiTaskRoberta模型架构
13
+ class MultiTaskRoberta(nn.Module):
14
+ def __init__(self, base_model):
15
+ super().__init__()
16
+ self.roberta = base_model
17
+ self.classifier = nn.Linear(768, 3)
18
+ self.regressor = nn.Linear(768, 5)
19
+
20
+ def forward(self, input_ids, attention_mask=None, **kwargs):
21
+ outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
22
+ pooled = outputs.last_hidden_state[:, 0]
23
+ logits = self.classifier(pooled)
24
+ regs = self.regressor(pooled)
25
+ return {"logits": logits, "regression_outputs": regs}
26
+
27
+ # 2. 准备模型和tokenizer
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ print(f"使用设备: {device}")
30
+
31
+ # 加载tokenizer
32
+ tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
33
+
34
+ # 加载模型
35
+ base_model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")
36
+ model = MultiTaskRoberta(base_model)
37
+
38
+ # 加载权重
39
+ model_path = "/content/robert-distilled-model/model.safetensors"
40
+ state_dict = load_file(model_path, device="cpu")
41
+ model.load_state_dict(state_dict)
42
+ model.to(device)
43
+ model.eval()
44
+
45
+ # 使用半精度减少内存占用
46
+ # if device.type == 'cuda':
47
+ # model.half()
48
+ # print("使用半精度模型")
49
+
50
+ # 3. 优化后的推理函数
51
+ def predict(text: str):
52
+ try:
53
+ inputs = tokenizer(
54
+ text,
55
+ return_tensors="pt",
56
+ truncation=True,
57
+ padding="max_length",
58
+ max_length=128
59
+ )
60
+
61
+ # 将输入移到设备
62
+ inputs = {k: v.to(device) for k, v in inputs.items()}
63
+
64
+ with torch.no_grad():
65
+ if device.type == 'cuda':
66
+ with torch.cuda.amp.autocast():
67
+ out = model(**inputs)
68
+ else:
69
+ out = model(**inputs)
70
+
71
+ pred_class = torch.argmax(out["logits"], dim=-1).item()
72
+ sentiment_map = {0: "正面", 1: "負面", 2: "中立"}
73
+
74
+ # 将结果移回CPU处理
75
+ reg_results = out["regression_outputs"][0].cpu().numpy()
76
+ rating, delight, anger, sorrow, happiness = reg_results
77
+
78
+ return {
79
+ "情感": sentiment_map[pred_class],
80
+ "評分": round(rating, 2),
81
+ "喜悅": round(delight, 2),
82
+ "憤怒": round(anger, 2),
83
+ "悲傷": round(sorrow, 2),
84
+ "快樂": round(happiness, 2),
85
+ }
86
+ except Exception as e:
87
+ return {"错误": f"处理失败: {str(e)}"}
88
+
89
+ # 4. 创建Gradio界面
90
+ iface = gr.Interface(
91
+ fn=predict,
92
+ inputs=gr.Textbox(lines=3, placeholder="請輸入粵語文本...", label="粵語文本"),
93
+ outputs=gr.JSON(label="分析結果"),
94
+ title="粵語情感與情緒分析",
95
+ description="輸入粵語文本,分析情感(正面/負面/中立)和五種情緒評分",
96
+ examples=[
97
+ ["呢個plan聽落唔錯,我哋試下先啦。"],
98
+ ["份proposal 你send 咗俾client未?Deadline 係EOD呀。"],
99
+ ["返工返到好攰,但係見到同事就feel better啲。"],
100
+ ["你今次嘅presentation做得唔錯,我好 impressed!"],
101
+ ["夜晚聽到嗰啲聲,我唔敢出房門。"],
102
+ ["個client 真係好 difficult 囉,改咗n 次 requirements,仲要urgent,chur 到痴線!"],
103
+ ["我尋日冇乜特別事做,就係喺屋企睇電視。"],
104
+ ["Weekend 去staycation,間酒店個view 正到爆!"],
105
+ ["做乜嘢都冇意義。"],
106
+ ["今朝遲到咗,差啲miss咗個重要meeting"],
107
+
108
+ ]
109
+ )
110
+
111
+ # 5. 启动应用 - 使用兼容的启动方式
112
+ if __name__ == "__main__":
113
+ iface.launch(share=True, show_error=True)