# -*- coding: utf-8 -*-
"""
Aphasia classification inference (cleaned).

- Respects the ``model_dir`` argument.
- Correctly parses durations such as ``["word", 300]`` and ``[start, end]``.
- Loads the checkpoint with a single ``load_state_dict`` call.
- Adds the ``predict_from_chajson(model_dir, chajson_path, ...)`` helper.
"""
import json
import math
import os
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


DEFAULT_MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Label -> class-id mapping used when <model_dir>/config.json does not provide one.
DEFAULT_APHASIA_TYPES_MAPPING: Dict[str, int] = {
    "BROCA": 0,
    "TRANSMOTOR": 1,
    "NOTAPHASICBYWAB": 2,
    "CONDUCTION": 3,
    "WERNICKE": 4,
    "ANOMIC": 5,
    "GLOBAL": 6,
    "ISOLATION": 7,
    "TRANSSENSORY": 8,
}

# Keys of a preprocessed sample that are passed to StableAphasiaClassifier.forward.
_MODEL_INPUT_KEYS = (
    "input_ids",
    "attention_mask",
    "word_pos_ids",
    "word_grammar_ids",
    "word_durations",
    "prosody_features",
)


def _fit_to_length(values, n: int, make_default: Callable[[], object]) -> list:
    """Return ``values`` truncated or padded to exactly ``n`` items.

    ``None`` or an empty sequence is treated as "no values". Padding calls
    ``make_default`` once per slot, so mutable defaults such as ``[0, 0, 0]``
    are never shared between slots.
    """
    fitted = list(values)[:n] if values else []
    fitted.extend(make_default() for _ in range(n - len(fitted)))
    return fitted


# =========================
# Model definition (unchanged shape)
# =========================
# NOTE: attribute names and nn.Sequential layouts below must stay exactly as they
# are, because they define the state_dict keys of the saved checkpoint.

@dataclass
class ModelConfig:
    """Hyper-parameters of :class:`StableAphasiaClassifier`.

    The dimensions determine parameter shapes, so they have to match the values
    the checkpoint was trained with for ``load_state_dict`` to succeed.
    """

    model_name: str = DEFAULT_MODEL_NAME
    max_length: int = 512
    hidden_size: int = 768
    pos_vocab_size: int = 150
    pos_emb_dim: int = 64
    grammar_dim: int = 3
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32
    num_attention_heads: int = 8
    attention_dropout: float = 0.3
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3

    def __post_init__(self):
        # None means "use the default head"; building the list here gives every
        # config its own list instead of a shared mutable default.
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]

    @property
    def linguistic_dim(self) -> int:
        """Width of the pooled linguistic feature vector (half the concatenated width)."""
        total = self.pos_emb_dim + self.grammar_hidden_dim + self.duration_hidden_dim + self.prosody_dim
        return total // 2


class StablePositionalEncoding(nn.Module):
    """Add a damped (x0.1) sum of fixed sinusoidal and learned positional encodings."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Return ``x`` plus positional encodings for its first ``x.size(1)`` positions."""
        seq_len = x.size(1)
        sinusoidal = self.pe[:, :seq_len, :].to(x.device)
        learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (sinusoidal + learnable)


class StableMultiHeadAttention(nn.Module):
    """Multi-head self-attention with a residual connection and post-LayerNorm."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        if feature_dim % num_heads != 0:
            raise ValueError(f"feature_dim ({feature_dim}) must be divisible by num_heads ({num_heads})")
        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq, feature_dim).

        ``mask`` marks valid key positions with non-zero values; a 2-D
        (batch, seq) mask is broadcast over heads and query positions.
        """
        b, t, _ = x.size()

        def split_heads(proj):
            return proj.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            # Use the dtype's most negative finite value: -1e9 overflows in fp16.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = self.dropout(F.softmax(scores, dim=-1))
        ctx = torch.matmul(attn, v).transpose(1, 2).contiguous().view(b, t, self.feature_dim)
        return self.layer_norm(self.output_proj(ctx) + x)


class StableLinguisticFeatureExtractor(nn.Module):
    """Fuse per-position POS, grammar, duration and prosody features, then mean-pool them."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3),
        )
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim),
        )
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim),
        )
        total_feature_dim = (
            config.pos_emb_dim + config.grammar_hidden_dim + config.duration_hidden_dim + config.prosody_dim
        )
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return a (batch, config.linguistic_dim) vector, mean-pooled over unmasked positions.

        ``pos_ids`` and ``durations`` are (batch, seq); ``grammar_ids`` adds a
        trailing ``grammar_dim`` axis and ``prosody_features`` a trailing
        ``prosody_dim`` axis.
        """
        # Out-of-vocabulary POS ids are clamped rather than raising in nn.Embedding.
        pos_ids = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_feat = self.pos_attention(self.pos_embedding(pos_ids), attention_mask)
        gra_feat = self.grammar_projection(grammar_ids.float())
        dur_feat = self.duration_projection(durations.unsqueeze(-1).float())
        pro_feat = self.prosody_projection(prosody_features.float())

        fused = self.feature_fusion(torch.cat([pos_feat, gra_feat, dur_feat, pro_feat], dim=-1))
        mask_exp = attention_mask.unsqueeze(-1).float()
        # Clamp the denominator so an all-padding row pools to zeros instead of NaN.
        return torch.sum(fused * mask_exp, dim=1) / torch.sum(mask_exp, dim=1).clamp(min=1e-9)


class StableAphasiaClassifier(nn.Module):
    """BERT encoder + linguistic features -> aphasia class logits, severity and fluency heads."""

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config
        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size, max_len=config.max_length
        )
        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        lingu_dim = config.linguistic_dim
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + lingu_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate),
        )
        self.classifier = self._build_classifier(bert_dim, num_labels)
        # 4-way softmax over severity levels and a sigmoid fluency score in [0, 1].
        self.severity_head = nn.Sequential(nn.Linear(bert_dim, 4), nn.Softmax(dim=-1))
        self.fluency_head = nn.Sequential(nn.Linear(bert_dim, 1), nn.Sigmoid())

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Build the MLP head: [Linear, LayerNorm, Tanh, Dropout] per hidden size, then Linear."""
        layers, cur = [], input_dim
        for h in self.config.classifier_hidden_dims:
            layers += [nn.Linear(cur, h), nn.LayerNorm(h), nn.Tanh(), nn.Dropout(self.config.dropout_rate)]
            cur = h
        layers.append(nn.Linear(cur, num_labels))
        return nn.Sequential(*layers)

    def _attention_pooling(self, seq_out, attn_mask):
        """Pool ``seq_out`` with softmax weights from per-position feature sums, restricted to the mask."""
        attn_w = torch.softmax(torch.sum(seq_out, dim=-1, keepdim=True), dim=1)
        attn_w = attn_w * attn_mask.unsqueeze(-1).float()
        attn_w = attn_w / (torch.sum(attn_w, dim=1, keepdim=True) + 1e-9)
        return torch.sum(seq_out * attn_w, dim=1)

    def forward(self, input_ids, attention_mask, labels=None, word_pos_ids=None,
                word_grammar_ids=None, word_durations=None, prosody_features=None, **kwargs):
        """Return ``{"logits", "severity_pred", "fluency_pred", "loss": None}``.

        Linguistic features are used only when POS, grammar and duration inputs
        are all given; otherwise a zero vector takes their place. ``labels`` is
        accepted for API compatibility but no loss is computed.
        """
        seq_out = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self._attention_pooling(self.positional_encoder(seq_out), attention_mask)

        if all(x is not None for x in (word_pos_ids, word_grammar_ids, word_durations)):
            if prosody_features is None:
                b, t = input_ids.size()
                prosody_features = torch.zeros(b, t, self.config.prosody_dim, device=input_ids.device)
            ling = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations, prosody_features, attention_mask
            )
        else:
            ling = torch.zeros(
                input_ids.size(0), self.config.linguistic_dim, device=input_ids.device, dtype=pooled.dtype
            )

        fused = self.feature_fusion(torch.cat([pooled, ling], dim=1))
        return {
            "logits": self.classifier(fused),
            "severity_pred": self.severity_head(fused),
            "fluency_pred": self.fluency_head(fused),
            "loss": None,
        }


# =========================
# Inference system (fixed wiring)
# =========================
class AphasiaInferenceSystem:
    """Aphasia classification inference system.

    Loads the tokenizer, label configuration and checkpoint from ``model_dir``
    and turns cha_json-style sentence dicts into predictions.
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir  # honour the argument instead of a hard-coded path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Human-readable class descriptions attached to each prediction.
        self.aphasia_descriptions = {
            "BROCA": {
                "name": "Broca's Aphasia (Non-fluent)",
                "description": "Characterized by limited speech output, difficulty with grammar and sentence formation, but relatively preserved comprehension. Speech is typically effortful and halting.",
                "features": ["Non-fluent speech", "Preserved comprehension", "Grammar difficulties", "Word-finding problems"],
            },
            "TRANSMOTOR": {
                "name": "Trans-cortical Motor Aphasia",
                "description": "Similar to Broca's aphasia but with preserved repetition abilities. Speech is non-fluent with good comprehension.",
                "features": ["Non-fluent speech", "Good repetition", "Preserved comprehension", "Grammar difficulties"],
            },
            "NOTAPHASICBYWAB": {
                "name": "Not Aphasic by WAB",
                "description": "Individuals who do not meet the criteria for aphasia according to the Western Aphasia Battery assessment.",
                "features": ["Normal language function", "No significant language impairment", "Good comprehension", "Fluent speech"],
            },
            "CONDUCTION": {
                "name": "Conduction Aphasia",
                "description": "Characterized by fluent speech with good comprehension but severely impaired repetition. Often involves phonemic paraphasias.",
                "features": ["Fluent speech", "Good comprehension", "Poor repetition", "Phonemic errors"],
            },
            "WERNICKE": {
                "name": "Wernicke's Aphasia (Fluent)",
                "description": "Fluent but often meaningless speech with poor comprehension. Speech may contain neologisms and jargon.",
                "features": ["Fluent speech", "Poor comprehension", "Jargon speech", "Neologisms"],
            },
            "ANOMIC": {
                "name": "Anomic Aphasia",
                "description": "Primarily characterized by word-finding difficulties with otherwise relatively preserved language abilities.",
                "features": ["Word-finding difficulties", "Good comprehension", "Fluent speech", "Circumlocution"],
            },
            "GLOBAL": {
                "name": "Global Aphasia",
                "description": "Severe impairment in all language modalities - comprehension, production, repetition, and naming.",
                "features": ["Severe comprehension deficit", "Non-fluent speech", "Poor repetition", "Severe naming difficulties"],
            },
            "ISOLATION": {
                "name": "Isolation Syndrome",
                "description": "Rare condition with preserved repetition but severely impaired comprehension and spontaneous speech.",
                "features": ["Good repetition", "Poor comprehension", "Limited spontaneous speech", "Echolalia"],
            },
            "TRANSSENSORY": {
                "name": "Trans-cortical Sensory Aphasia",
                "description": "Fluent speech with good repetition but impaired comprehension, similar to Wernicke's but with preserved repetition.",
                "features": ["Fluent speech", "Good repetition", "Poor comprehension", "Semantic errors"],
            },
        }

        self.load_configuration()
        self.load_model()
        print(f"推理系統初始化完成,使用設備: {self.device}")

    def load_configuration(self):
        """Read label mapping, label count and backbone name from ``<model_dir>/config.json``.

        Any key missing from the file (or a missing file) falls back to the
        module defaults. Also builds the inverse ``id_to_aphasia_type`` map.
        """
        cfg = {}
        cfg_path = os.path.join(self.model_dir, "config.json")
        if os.path.exists(cfg_path):
            with open(cfg_path, "r", encoding="utf-8") as f:
                cfg = json.load(f)

        self.aphasia_types_mapping = cfg.get("aphasia_types_mapping", dict(DEFAULT_APHASIA_TYPES_MAPPING))
        self.num_labels = cfg.get("num_labels", 9)
        self.model_name = cfg.get("model_name", DEFAULT_MODEL_NAME)
        self.id_to_aphasia_type = {v: k for k, v in self.aphasia_types_mapping.items()}

    def load_model(self):
        """Load tokenizer and checkpoint from ``model_dir`` and put the model in eval mode.

        Raises:
            FileNotFoundError: if ``<model_dir>/pytorch_model.bin`` does not exist.
        """
        model_path = os.path.join(self.model_dir, "pytorch_model.bin")
        # Check before building the model so a bad path fails before the backbone download.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型權重文件不存在: {model_path}")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir, use_fast=True)
        self._ensure_pad_token()
        self._add_extra_tokens()

        self.config = ModelConfig()
        self.config.model_name = self.model_name
        self.model = StableAphasiaClassifier(self.config, self.num_labels)
        # Match the embedding matrix to the tokenizer (pad/added tokens may have grown it).
        self.model.bert.resize_token_embeddings(len(self.tokenizer))

        self.model.load_state_dict(self._load_state_dict(model_path))  # loaded exactly once
        self.model.to(self.device)
        self.model.eval()

    def _ensure_pad_token(self):
        """Give the tokenizer a pad token (needed for padding="max_length"): eos, then unk, then a new [PAD]."""
        tok = self.tokenizer
        if tok.pad_token is not None:
            return
        if tok.eos_token is not None:
            tok.pad_token = tok.eos_token
        elif tok.unk_token is not None:
            tok.pad_token = tok.unk_token
        else:
            tok.add_special_tokens({"pad_token": "[PAD]"})

    def _add_extra_tokens(self):
        """Add tokens listed in ``<model_dir>/added_tokens.json`` (a dict's keys or a plain list)."""
        add_path = os.path.join(self.model_dir, "added_tokens.json")
        if not os.path.exists(add_path):
            return
        with open(add_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        tokens = list(data.keys()) if isinstance(data, dict) else data
        if tokens:
            self.tokenizer.add_tokens(tokens)

    def _load_state_dict(self, model_path: str):
        """Load a state dict from ``model_path`` onto ``self.device``."""
        # torch.load is pickle-based; weights_only=True refuses arbitrary objects in
        # an untrusted checkpoint. Torch versions without the keyword raise TypeError.
        try:
            return torch.load(model_path, map_location=self.device, weights_only=True)
        except TypeError:
            return torch.load(model_path, map_location=self.device)

    # ---------- helpers ----------
    def _dur_to_float(self, d) -> float:
        """Parse a duration from the shapes seen in cha_json output; unknown shapes give 0.0.

        Accepted forms:
            - a number
            - ``["word", ms]``  -> ``ms``
            - ``[start, end]``  -> ``end - start``
            - ``{"dur" | "duration" | "ms": ms}`` (first matching key wins)
        Tuples are accepted wherever lists are.
        """
        if isinstance(d, (int, float)):
            return float(d)
        if isinstance(d, (list, tuple)) and len(d) == 2:
            a, b = d
            if isinstance(b, (int, float)):
                if isinstance(a, str):  # ["word", 300]
                    return float(b)
                if isinstance(a, (int, float)):  # [start, end]
                    return float(b) - float(a)
        if isinstance(d, dict):
            for k in ("dur", "duration", "ms"):
                value = d.get(k)
                if isinstance(value, (int, float)):
                    return float(value)
        return 0.0

    def _extract_prosodic_features(self, durations, tokens):
        """Summarise positive durations into a ``prosody_dim``-long vector.

        Slots are ``[mean, std, median, count(> 1.5 * mean)]`` followed by zero
        padding. Returns all zeros when no duration is positive. ``tokens`` is
        accepted for interface compatibility and is not used.
        """
        vals = [v for v in (self._dur_to_float(d) for d in durations) if v > 0]
        if not vals:
            return [0.0] * self.config.prosody_dim

        arr = np.asarray(vals, dtype=float)
        mean = float(np.mean(arr))  # computed once; it was previously recomputed per element
        features = [
            mean,
            float(np.std(arr)),
            float(np.median(arr)),
            float(np.count_nonzero(arr > mean * 1.5)),
        ]
        features.extend([0.0] * max(0, self.config.prosody_dim - len(features)))
        return features[: self.config.prosody_dim]

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Project word-level features onto the ``max_length`` subtoken positions.

        Each word's features are copied to every subtoken it produces, mirroring
        how ``" ".join(tokens)`` is encoded with truncation. [CLS], [SEP] and
        padding positions get zero features. ``encoded`` is accepted for
        interface compatibility and is not used.

        Returns:
            ``(aligned_pos, aligned_grammar, aligned_durations)``, each of
            length ``config.max_length``.
        """
        # A word that tokenizes to nothing (e.g. "") also contributes no subtokens
        # to the joined text, so it must get no slot here; otherwise every later
        # word would be shifted onto its neighbour's features.
        subtoken_to_token = []
        for idx, tok in enumerate(tokens):
            subtoken_to_token.extend([idx] * len(self.tokenizer.tokenize(tok)))

        aligned_pos = [0]  # [CLS]
        aligned_grammar = [[0, 0, 0]]  # [CLS]
        aligned_durations = [0.0]  # [CLS]

        max_body = self.config.max_length - 2  # leave room for [CLS] and [SEP]
        for st_idx in range(max_body):
            if st_idx < len(subtoken_to_token):
                orig = subtoken_to_token[st_idx]
                aligned_pos.append(pos_ids[orig] if orig < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[orig] if orig < len(grammar_ids) else [0, 0, 0])
                aligned_durations.append(self._dur_to_float(durations[orig]) if orig < len(durations) else 0.0)
            else:  # padding
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)  # [SEP]
        aligned_grammar.append([0, 0, 0])  # [SEP]
        aligned_durations.append(0.0)  # [SEP]
        return aligned_pos, aligned_grammar, aligned_durations

    def preprocess_sentence(self, sentence_data: dict) -> Optional[dict]:
        """Turn one cha_json sentence into model-ready tensors.

        Participant (``PAR``) tokens of all dialogues are concatenated, with a
        ``"[DIALOGUE]"`` marker between consecutive dialogues. Per-word feature
        lists are padded or truncated to each utterance's token count, so a
        short or missing list cannot shift the features of later utterances.

        Returns:
            A dict of tensors and metadata, or ``None`` if no participant tokens exist.
        """
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []
        has_participant_tokens = False

        for d_idx, dialogue in enumerate(sentence_data.get("dialogues", [])):
            if d_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)
            for par in dialogue.get("PAR", []):
                toks = par.get("tokens")
                if not toks:
                    continue
                n = len(toks)
                has_participant_tokens = True
                all_tokens.extend(toks)
                all_pos.extend(_fit_to_length(par.get("word_pos_ids"), n, lambda: 0))
                all_grammar.extend(_fit_to_length(par.get("word_grammar_ids"), n, lambda: [0, 0, 0]))
                all_durations.extend(_fit_to_length(par.get("word_durations"), n, lambda: 0.0))

        # Separator markers alone are not content worth classifying.
        if not has_participant_tokens:
            return None

        text = " ".join(all_tokens)
        enc = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        aligned_pos, aligned_gra, aligned_dur = self._align_features(
            all_tokens, all_pos, all_grammar, all_durations, enc
        )
        # The same utterance-level prosody vector is repeated at every position.
        prosody = self._extract_prosodic_features(all_durations, all_tokens)
        prosody_tensor = torch.tensor(prosody).unsqueeze(0).repeat(self.config.max_length, 1)

        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_gra, dtype=torch.long),
            "word_durations": torch.tensor(aligned_dur, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_data.get("sentence_id", "unknown"),
            "original_tokens": all_tokens,
            "text": text,
        }

    def predict_single(self, sentence_data: dict) -> dict:
        """Classify one sentence dict.

        Returns:
            The predicted class, its confidence, description, full probability
            distribution (sorted high to low) and severity/fluency outputs, or
            ``{"error": ..., "sentence_id": ...}`` if preprocessing yields nothing.
        """
        proc = self.preprocess_sentence(sentence_data)
        if proc is None:
            return {"error": "無法處理輸入數據", "sentence_id": sentence_data.get("sentence_id", "unknown")}

        inp = {key: proc[key].unsqueeze(0).to(self.device) for key in _MODEL_INPUT_KEYS}
        with torch.no_grad():
            out = self.model(**inp)
            probs = F.softmax(out["logits"], dim=1).cpu().numpy()[0]
            pred_id = int(np.argmax(probs))
            sev = out["severity_pred"].cpu().numpy()[0]
            flu = float(out["fluency_pred"].cpu().numpy()[0][0])

        # Fall back to the numeric id if the label mapping does not cover it.
        pred_type = self.id_to_aphasia_type.get(pred_id, str(pred_id))
        conf = float(probs[pred_id])

        dist = {
            a_type: {"probability": float(probs[t_id]), "percentage": f"{probs[t_id]*100:.2f}%"}
            for a_type, t_id in self.aphasia_types_mapping.items()
        }
        sorted_dist = dict(sorted(dist.items(), key=lambda item: item[1]["probability"], reverse=True))

        return {
            "sentence_id": proc["sentence_id"],
            "input_text": proc["text"],
            "original_tokens": proc["original_tokens"],
            "prediction": {
                "predicted_class": pred_type,
                "confidence": conf,
                "confidence_percentage": f"{conf*100:.2f}%",
            },
            "class_description": self.aphasia_descriptions.get(pred_type, {
                "name": pred_type,
                "description": "Description not available",
                "features": [],
            }),
            "probability_distribution": sorted_dist,
            "additional_predictions": {
                "severity_distribution": {f"level_{i}": float(sev[i]) for i in range(len(sev))},
                "predicted_severity_level": int(np.argmax(sev)),
                "fluency_score": flu,
                "fluency_rating": "High" if flu > 0.7 else ("Medium" if flu > 0.4 else "Low"),
            },
        }

    def predict_batch(self, input_file: str, output_file: Optional[str] = None) -> Dict:
        """Classify every entry of ``input_file``'s ``"sentences"`` list.

        Returns ``{"summary", "total_sentences", "predictions"}`` and also writes
        it as JSON to ``output_file`` when one is given.
        """
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        return self._predict_sentences(data.get("sentences", []), output_file)

    def _predict_sentences(self, sentences: List[dict], output_file: Optional[str] = None) -> Dict:
        """Run ``predict_single`` over in-memory sentences; shared by file-based entry points."""
        results = []
        print(f"開始處理 {len(sentences)} 個句子...")
        for i, s in enumerate(sentences):
            print(f"處理第 {i+1}/{len(sentences)} 個句子...")
            results.append(self.predict_single(s))

        final = {
            "summary": self._generate_summary(results),
            "total_sentences": len(results),
            "predictions": results,
        }
        if output_file:
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(final, f, ensure_ascii=False, indent=2)
            print(f"結果已保存到: {output_file}")
        return final

    def _generate_summary(self, results: List[dict]) -> dict:
        """Aggregate class counts, confidence and fluency statistics over successful predictions.

        Error results are skipped in the counts, but the class percentages are
        taken over all results (errors included). Returns ``{}`` for no results.
        """
        if not results:
            return {}

        class_counts = defaultdict(int)
        confs, flus = [], []
        sev_counts = defaultdict(int)
        for r in results:
            if "error" in r:
                continue
            class_counts[r["prediction"]["predicted_class"]] += 1
            confs.append(r["prediction"]["confidence"])
            flus.append(r["additional_predictions"]["fluency_score"])
            sev_counts[r["additional_predictions"]["predicted_severity_level"]] += 1

        avg_conf = float(np.mean(confs)) if confs else 0.0
        avg_flu = float(np.mean(flus)) if flus else 0.0

        return {
            "classification_distribution": dict(class_counts),
            "classification_percentages": {k: f"{v/len(results)*100:.1f}%" for k, v in class_counts.items()},
            "average_confidence": f"{avg_conf:.3f}",
            "average_fluency_score": f"{avg_flu:.3f}",
            "severity_distribution": dict(sev_counts),
            "confidence_statistics": {} if not confs else {
                "mean": f"{np.mean(confs):.3f}",
                "std": f"{np.std(confs):.3f}",
                "min": f"{np.min(confs):.3f}",
                "max": f"{np.max(confs):.3f}",
            },
            "most_common_prediction": max(class_counts.items(), key=lambda x: x[1])[0] if class_counts else "None",
        }

    def generate_detailed_report(self, results: List[dict], output_dir: str = "./inference_results"):
        """Write ``detailed_predictions.csv`` and ``summary_statistics.json`` into ``output_dir``.

        Error results are skipped. Returns the per-sentence DataFrame, or
        ``None`` when there is nothing to report (the directory is still created).
        """
        os.makedirs(output_dir, exist_ok=True)
        rows = []
        for r in results:
            if "error" in r:
                continue
            row = {
                "sentence_id": r["sentence_id"],
                "predicted_class": r["prediction"]["predicted_class"],
                "confidence": r["prediction"]["confidence"],
                "class_name": r["class_description"]["name"],
                "severity_level": r["additional_predictions"]["predicted_severity_level"],
                "fluency_score": r["additional_predictions"]["fluency_score"],
                "fluency_rating": r["additional_predictions"]["fluency_rating"],
                "input_text": r["input_text"],
            }
            for a_type, info in r["probability_distribution"].items():
                row[f"prob_{a_type}"] = info["probability"]
            rows.append(row)

        if not rows:
            return None

        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(output_dir, "detailed_predictions.csv"), index=False, encoding="utf-8")

        def _std(col: str) -> float:
            # Sample std of a single row is NaN, which json.dump would write as the
            # non-standard token NaN; report 0.0 instead.
            value = df[col].std()
            return float(value) if pd.notna(value) else 0.0

        def _counts(col: str, key_type) -> dict:
            # Cast to built-in types so json.dump never receives numpy scalars.
            return {key_type(k): int(v) for k, v in df[col].value_counts().items()}

        summary_stats = {
            "total_predictions": int(len(rows)),
            "class_distribution": _counts("predicted_class", str),
            "average_confidence": float(df["confidence"].mean()),
            "confidence_std": _std("confidence"),
            "average_fluency": float(df["fluency_score"].mean()),
            "fluency_std": _std("fluency_score"),
            "severity_distribution": _counts("severity_level", int),
        }
        with open(os.path.join(output_dir, "summary_statistics.json"), "w", encoding="utf-8") as f:
            json.dump(summary_stats, f, ensure_ascii=False, indent=2)

        print(f"詳細報告已生成並保存到: {output_dir}")
        return df


# =========================
# Convenience: run directly or from pipeline
# =========================
def predict_from_chajson(model_dir: str, chajson_path: str, output_file: Optional[str] = None) -> Dict:
    """Classify the JSON produced by cha_json.py.

    - If it contains a non-empty ``"sentences"`` list, each sentence is classified.
    - Otherwise ``"text_all"`` is whitespace-split into a single pseudo-sentence
      ``"S1"`` with zeroed linguistic features.

    Everything is processed in memory; no temporary file is written next to the input.
    """
    with open(chajson_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    inf = AphasiaInferenceSystem(model_dir)

    sentences = data.get("sentences")
    if sentences:
        return inf._predict_sentences(sentences, output_file=output_file)

    tokens = (data.get("text_all") or "").split()
    n = len(tokens)
    synthetic_sentence = {
        "sentence_id": "S1",
        "dialogues": [{
            "INV": [],
            "PAR": [{
                "tokens": tokens,
                "word_pos_ids": [0] * n,
                "word_grammar_ids": [[0, 0, 0] for _ in range(n)],
                "word_durations": [0.0] * n,
            }],
        }],
    }
    return inf._predict_sentences([synthetic_sentence], output_file=output_file)


def format_result(pred: dict, style: str = "json") -> str:
    """Back-compat formatter. ``pred`` is the dict returned by ``predict_*``.

    ``style="json"`` pretty-prints the whole dict; any other style yields a
    short text summary when ``pred`` has a ``"summary"``, else ``str(pred)``.
    """
    if style == "json":
        return json.dumps(pred, ensure_ascii=False, indent=2)

    if isinstance(pred, dict) and "summary" in pred:
        s = pred["summary"]
        lines = [
            f"Total sentences: {pred.get('total_sentences', 0)}",
            f"Avg confidence: {s.get('average_confidence', 'N/A')}",
            f"Avg fluency: {s.get('average_fluency_score', 'N/A')}",
            f"Most common: {s.get('most_common_prediction', 'N/A')}",
        ]
        return "\n".join(lines)
    return str(pred)


# ---------- CLI ----------
def main():
    """Command-line entry point: batch-predict a cha_json file and print a summary.

    Exits with status 1 (after printing the traceback) if anything fails.
    """
    import argparse

    p = argparse.ArgumentParser(description="失語症分類推理系統")
    p.add_argument("--model_dir", type=str, required=False, default="./adaptive_aphasia_model",
                   help="訓練好的模型目錄路徑")
    p.add_argument("--input_file", type=str, required=True,
                   help="輸入JSON文件(cha_json 的輸出)")
    p.add_argument("--output_file", type=str, default="./aphasia_predictions.json",
                   help="輸出JSON文件路徑")
    p.add_argument("--report_dir", type=str, default="./inference_results",
                   help="詳細報告輸出目錄")
    p.add_argument("--generate_report", action="store_true",
                   help="是否生成詳細的CSV報告")
    args = p.parse_args()

    try:
        print("正在初始化推理系統...")
        inference = AphasiaInferenceSystem(args.model_dir)

        print("開始執行批次預測...")
        results = inference.predict_batch(args.input_file, args.output_file)

        if args.generate_report:
            print("生成詳細報告...")
            inference.generate_detailed_report(results["predictions"], args.report_dir)

        print("\n=== 預測摘要 ===")
        s = results["summary"]
        print(f"總句子數: {results['total_sentences']}")
        print(f"平均信心度: {s.get('average_confidence', 'N/A')}")
        print(f"平均流利度: {s.get('average_fluency_score', 'N/A')}")
        print(f"最常見預測: {s.get('most_common_prediction', 'N/A')}")
        print("\n類別分佈:")
        for name, count in s.get("classification_distribution", {}).items():
            pct = s.get("classification_percentages", {}).get(name, "0%")
            print(f" {name}: {count} ({pct})")
        print(f"\n結果已保存到: {args.output_file}")
    except Exception as e:
        print(f"錯誤: {str(e)}")
        traceback.print_exc()
        # Report failure to the calling shell/pipeline instead of exiting with 0.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()