# -*- coding: utf-8 -*-
"""
Advanced Multi-Modal Aphasia Classification System
With Adaptive Learning Rate and Comprehensive Reporting
"""
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import datetime
import numpy as np
import os
import random
import math
from collections import Counter
from typing import List
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from transformers import (
    AutoTokenizer, AutoModel, TrainingArguments, Trainer, TrainerCallback
)
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    confusion_matrix, classification_report
)
from sklearn.model_selection import StratifiedKFold
import gc
from scipy import stats

# Environment setup for stability
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

json_file = '/workspace/SH001/aphasia_data_augmented.json'

# Set seeds for reproducibility
def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_all_seeds(42)

# Configuration
@dataclass
class ModelConfig:
    # Model architecture
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512
    hidden_size: int = 768

    # Feature dimensions
    pos_vocab_size: int = 150
    pos_emb_dim: int = 64
    grammar_dim: int = 3
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32

    # Multi-head attention
    num_attention_heads: int = 8
    attention_dropout: float = 0.3

    # Classification head
    classifier_hidden_dims: List[int] = None
    dropout_rate: float = 0.3
    activation_fn: str = "tanh"

    # Training
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    batch_size: int = 10
    num_epochs: int = 500
    gradient_accumulation_steps: int = 4

    # Adaptive learning rate parameters
    adaptive_lr: bool = True
    lr_patience: int = 3             # Patience for learning rate adjustment
    lr_factor: float = 0.8           # Factor to multiply learning rate
    lr_increase_factor: float = 1.2  # Factor to increase learning rate
    min_lr: float = 1e-6
    max_lr: float = 1e-3
    oscillation_amplitude: float = 0.1  # For sinusoidal oscillation

    # Advanced techniques
    use_focal_loss: bool = True
    focal_alpha: float = 1.0
    focal_gamma: float = 2.0
    use_mixup: bool = False
    mixup_alpha: float = 0.2
    use_label_smoothing: bool = True
    label_smoothing: float = 0.1

    def __post_init__(self):
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]

# Utility functions
def log_message(message):
    timestamp = datetime.datetime.now().isoformat()
    full_message = f"{timestamp}: {message}"
    log_file = "./training_log.txt"
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(full_message + "\n")
    print(full_message, flush=True)

def clear_memory():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def normalize_type(t):
    return t.strip().upper() if isinstance(t, str) else t

# Adaptive Learning Rate Scheduler
class AdaptiveLearningRateScheduler:
    """Adaptive learning-rate scheduler combining several strategies."""

    def __init__(self, optimizer, config: ModelConfig, total_steps: int):
        self.optimizer = optimizer
        self.config = config
        self.total_steps = total_steps

        # History
        self.loss_history = []
        self.f1_history = []
        self.accuracy_history = []
        self.lr_history = []

        # State tracking
        self.plateau_counter = 0
        self.best_f1 = 0.0
        self.best_loss = float('inf')
        self.step_count = 0

        # Initial learning rate
        self.base_lr = config.learning_rate
        self.current_lr = self.base_lr

        log_message(f"Adaptive LR Scheduler initialized with base_lr={self.base_lr}")

    def calculate_slope(self, values, window=3):
        """Compute the slope of the most recent values."""
        if len(values) < window:
            return 0.0
        recent_values = values[-window:]
        x = np.arange(len(recent_values))
        slope, _, _, _, _ = stats.linregress(x, recent_values)
        return slope

    def exponential_adjustment(self, current_value, target_value, base_factor=1.1):
        """Exponential adjustment function."""
        ratio = current_value / target_value if target_value != 0 else 1.0
        factor = math.exp(-ratio) * base_factor
        return factor

    def logarithmic_adjustment(self, current_value, threshold=0.1):
        """Logarithmic adjustment function."""
        if current_value <= 0:
            return 1.0
        factor = math.log(1 + current_value / threshold)
        return max(0.5, min(2.0, factor))

    def sinusoidal_oscillation(self, step, amplitude=None):
        """Sinusoidal oscillation adjustment."""
        if amplitude is None:
            amplitude = self.config.oscillation_amplitude
        # Step-based sine wave: four full periods over the training run
        phase = 2 * math.pi * step / (self.total_steps / 4)
        oscillation = 1 + amplitude * math.sin(phase)
        return oscillation

    def cosine_decay(self, step):
        """Cosine decay."""
        progress = step / self.total_steps
        decay = 0.5 * (1 + math.cos(math.pi * progress))
        return decay

    def adaptive_lr_calculation(self, current_loss, current_f1, current_acc):
        """Adaptive learning-rate computation."""
        # Record history
        self.loss_history.append(current_loss)
        self.f1_history.append(current_f1)
        self.accuracy_history.append(current_acc)

        # Compute slopes
        loss_slope = self.calculate_slope(self.loss_history)
        f1_slope = self.calculate_slope(self.f1_history)
        acc_slope = self.calculate_slope(self.accuracy_history)

        # Base adjustment factor
        adjustment_factor = 1.0

        # 1. Adjustment based on the loss slope
        if abs(loss_slope) < 0.001:  # Loss plateau
            log_message(f"Loss plateau detected (slope: {loss_slope:.6f})")
            # Increase the learning rate exponentially
            exp_factor = self.exponential_adjustment(abs(loss_slope), 0.01, 1.15)
            adjustment_factor *= exp_factor
        elif current_loss > 2.0:  # Loss too high
            log_message(f"High loss detected: {current_loss:.4f}")
            # Logarithmic adjustment
            log_factor = self.logarithmic_adjustment(current_loss, 1.0)
            adjustment_factor *= log_factor

        # 2. Adjustment based on the F1 score
        if current_f1 < 0.3:  # F1 too low
            log_message(f"Low F1 detected: {current_f1:.4f}")
            # Increase the learning rate exponentially
            exp_factor = self.exponential_adjustment(0.3, current_f1, 1.2)
            adjustment_factor *= exp_factor
        elif abs(f1_slope) < 0.001:  # F1 plateau
            log_message(f"F1 plateau detected (slope: {f1_slope:.6f})")
            adjustment_factor *= 1.1

        # 3. Add sinusoidal oscillation
        sin_factor = self.sinusoidal_oscillation(self.step_count)

        # 4. Add cosine decay
        cos_factor = self.cosine_decay(self.step_count)

        # Combined adjustment
        final_factor = adjustment_factor * sin_factor * (0.3 + 0.7 * cos_factor)

        # New learning rate, clamped to the allowed range
        new_lr = self.current_lr * final_factor
        new_lr = max(self.config.min_lr, min(self.config.max_lr, new_lr))

        # Update the learning rate only if the change is large enough
        if abs(new_lr - self.current_lr) > 1e-7:
            self.current_lr = new_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr
            log_message(f"Learning rate adjusted: {new_lr:.2e} (factor: {final_factor:.3f})")
            log_message(f"  - Loss slope: {loss_slope:.6f}, F1 slope: {f1_slope:.6f}")
            log_message(f"  - Sin factor: {sin_factor:.3f}, Cos factor: {cos_factor:.3f}")

        self.lr_history.append(self.current_lr)
        self.step_count += 1

        return self.current_lr

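# Illustrative sketch (hypothetical helper, not called by the pipeline): how
# the adaptive scheduler reacts to a loss plateau. Uses only the classes
# defined above; the dummy parameter and metric values are made up.
def _demo_adaptive_scheduler():
    dummy_params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(dummy_params, lr=ModelConfig().learning_rate)
    scheduler = AdaptiveLearningRateScheduler(optimizer, ModelConfig(), total_steps=100)
    # Near-identical losses give calculate_slope() a near-zero slope, so the
    # plateau branch fires and the exponential factor nudges the LR upward.
    for loss, f1, acc in [(1.50, 0.40, 0.45), (1.49, 0.41, 0.46), (1.49, 0.41, 0.46)]:
        lr = scheduler.adaptive_lr_calculation(loss, f1, acc)
    print(f"LR after plateau: {lr:.2e}")
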
# Training History Tracker
class TrainingHistoryTracker:
    """Per-epoch training history recorder."""

    def __init__(self):
        self.history = {
            'epoch': [],
            'train_loss': [], 'eval_loss': [],
            'train_accuracy': [], 'eval_accuracy': [],
            'train_f1': [], 'eval_f1': [],
            'learning_rate': [],
            'train_precision': [], 'eval_precision': [],
            'train_recall': [], 'eval_recall': []
        }

    def update(self, epoch, metrics):
        """Append one epoch of metrics to the history."""
        self.history['epoch'].append(epoch)
        for key, value in metrics.items():
            if key in self.history:
                self.history[key].append(value)

    def save_history(self, output_dir):
        """Save the history to CSV."""
        df = pd.DataFrame(self.history)
        df.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
        return df

    def plot_training_curves(self, output_dir):
        """Plot the training curves."""
        if not self.history['epoch']:
            return

        # Figure style
        plt.style.use('seaborn-v0_8')
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        epochs = self.history['epoch']

        # 1. Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, 0].plot(epochs, self.history['eval_loss'], 'r-', label='Eval Loss', linewidth=2)
        axes[0, 0].set_title('Loss Over Time', fontsize=14, fontweight='bold')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # 2. Accuracy curves
        axes[0, 1].plot(epochs, self.history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
        axes[0, 1].plot(epochs, self.history['eval_accuracy'], 'r-', label='Eval Accuracy', linewidth=2)
        axes[0, 1].set_title('Accuracy Over Time', fontsize=14, fontweight='bold')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # 3. F1 score curves
        axes[0, 2].plot(epochs, self.history['train_f1'], 'b-', label='Train F1', linewidth=2)
        axes[0, 2].plot(epochs, self.history['eval_f1'], 'r-', label='Eval F1', linewidth=2)
        axes[0, 2].set_title('F1 Score Over Time', fontsize=14, fontweight='bold')
        axes[0, 2].set_xlabel('Epoch')
        axes[0, 2].set_ylabel('F1 Score')
        axes[0, 2].legend()
        axes[0, 2].grid(True, alpha=0.3)

        # 4. Learning-rate curve
        axes[1, 0].plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        axes[1, 0].set_title('Learning Rate Over Time', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)

        # 5. Precision curves
        axes[1, 1].plot(epochs, self.history['train_precision'], 'b-', label='Train Precision', linewidth=2)
        axes[1, 1].plot(epochs, self.history['eval_precision'], 'r-', label='Eval Precision', linewidth=2)
        axes[1, 1].set_title('Precision Over Time', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Precision')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

        # 6. Recall curves
        axes[1, 2].plot(epochs, self.history['train_recall'], 'b-', label='Train Recall', linewidth=2)
        axes[1, 2].plot(epochs, self.history['eval_recall'], 'r-', label='Eval Recall', linewidth=2)
        axes[1, 2].set_title('Recall Over Time', fontsize=14, fontweight='bold')
        axes[1, 2].set_xlabel('Epoch')
        axes[1, 2].set_ylabel('Recall')
        axes[1, 2].legend()
        axes[1, 2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "training_curves.png"), dpi=300, bbox_inches='tight')
        plt.close()

# Focal loss implementation
class FocalLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

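# Illustrative sanity check (hypothetical helper, not called by the pipeline):
# with gamma=0 and alpha=1 the focal loss reduces to plain cross-entropy, which
# is a quick way to verify the implementation above.
def _demo_focal_loss():
    logits = torch.randn(4, 9)             # batch of 4, 9 aphasia classes
    targets = torch.tensor([0, 3, 5, 8])
    fl = FocalLoss(alpha=1.0, gamma=0.0)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    assert torch.allclose(fl, ce, atol=1e-6)
    # With gamma > 0 well-classified examples are down-weighted by (1 - pt)^gamma,
    # so the focal loss can never exceed the cross-entropy loss.
    assert FocalLoss(gamma=2.0)(logits, targets) <= ce
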
# Stable positional encoding
class StablePositionalEncoding(nn.Module):
    """Simplified but stable positional encoding"""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Traditional sinusoidal encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

        # Simple learnable component
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        seq_len = x.size(1)
        sinusoidal = self.pe[:, :seq_len, :].to(x.device)
        learnable = self.learnable_pe[:seq_len, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (sinusoidal + learnable)

# Stable multi-head attention
class StableMultiHeadAttention(nn.Module):
    """Stable multi-head attention for feature fusion"""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads
        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        Q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1).unsqueeze(1)
            scores.masked_fill_(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.feature_dim)

        output = self.output_proj(context)
        return self.layer_norm(output + x)

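# Illustrative shape check (hypothetical helper): the residual attention block
# must preserve the input shape, and masked key positions receive ~zero weight
# via the -1e9 fill above.
def _demo_attention_shapes():
    attn = StableMultiHeadAttention(feature_dim=64, num_heads=4)
    x = torch.randn(2, 10, 64)                # (batch, seq_len, feature_dim)
    mask = torch.ones(2, 10, dtype=torch.long)
    mask[:, 7:] = 0                           # last three positions are padding
    out = attn(x, mask)
    assert out.shape == x.shape
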
# Stable linguistic feature extractor
class StableLinguisticFeatureExtractor(nn.Module):
    """Stable linguistic feature processing"""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # POS embeddings
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Grammar feature processing
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3)
        )

        # Duration processing
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim)
        )

        # Prosody processing
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim)
        )

        # Feature fusion
        total_feature_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                             config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(total_feature_dim, total_feature_dim // 2),
            nn.Tanh(),
            nn.LayerNorm(total_feature_dim // 2),
            nn.Dropout(config.dropout_rate)
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        batch_size, seq_len = pos_ids.size()

        # Process POS features with clamping
        pos_ids_clamped = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_embeds = self.pos_embedding(pos_ids_clamped)
        pos_features = self.pos_attention(pos_embeds, attention_mask)

        # Process grammar features
        grammar_features = self.grammar_projection(grammar_ids.float())

        # Process duration features
        duration_features = self.duration_projection(durations.unsqueeze(-1).float())

        # Process prosodic features
        prosody_features = self.prosody_projection(prosody_features.float())

        # Combine features
        combined_features = torch.cat([
            pos_features, grammar_features, duration_features, prosody_features
        ], dim=-1)

        # Feature fusion
        fused_features = self.feature_fusion(combined_features)

        # Global pooling (mask-weighted mean over the sequence)
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled_features = torch.sum(fused_features * mask_expanded, dim=1) / torch.sum(mask_expanded, dim=1)

        return pooled_features

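# Illustrative shape check for the extractor (hypothetical helper). The pooled
# output width is half the concatenated feature width, matching the fusion
# layer above: (64 + 64 + 128 + 32) // 2 = 144 with the default config.
def _demo_linguistic_extractor():
    cfg = ModelConfig()
    extractor = StableLinguisticFeatureExtractor(cfg)
    batch, seq = 2, 16
    pooled = extractor(
        pos_ids=torch.randint(0, cfg.pos_vocab_size, (batch, seq)),
        grammar_ids=torch.zeros(batch, seq, cfg.grammar_dim, dtype=torch.long),
        durations=torch.rand(batch, seq),
        prosody_features=torch.rand(batch, seq, cfg.prosody_dim),
        attention_mask=torch.ones(batch, seq, dtype=torch.long),
    )
    assert pooled.shape == (batch, (cfg.pos_emb_dim + cfg.grammar_hidden_dim +
                                    cfg.duration_hidden_dim + cfg.prosody_dim) // 2)
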
# Main classifier with stability improvements
class StableAphasiaClassifier(nn.Module):
    """Stable aphasia classification model"""

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pre-trained model
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        # Freeze embeddings for stability
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        # Positional encoding
        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        # Linguistic feature extractor
        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        # Calculate dimensions
        bert_dim = self.bert_config.hidden_size
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Feature fusion
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        # Classifier
        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Multi-task heads (simplified)
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )
        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        layers = []
        current_dim = input_dim
        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim
        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        # BERT encoding
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        # Apply positional encoding
        position_enhanced = self.positional_encoder(sequence_output)

        # Attention pooling
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        # Process linguistic features
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )
            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        # Feature fusion
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # Predictions
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        # Loss computation
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": loss
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Attention-based pooling"""
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled

    def _compute_loss(self, logits, labels):
        if self.config.use_focal_loss:
            focal_loss = FocalLoss(
                alpha=self.config.focal_alpha,
                gamma=self.config.focal_gamma,
                reduction='mean'
            )
            return focal_loss(logits, labels)
        else:
            if self.config.use_label_smoothing:
                return F.cross_entropy(
                    logits, labels,
                    label_smoothing=self.config.label_smoothing
                )
            return F.cross_entropy(logits, labels)

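# Illustrative end-to-end smoke test (hypothetical helper; it downloads the
# pretrained PubMedBERT weights on first run, so it is not invoked here).
def _demo_classifier_forward():
    cfg = ModelConfig()
    model = StableAphasiaClassifier(cfg, num_labels=9).eval()
    batch, seq = 2, cfg.max_length
    with torch.no_grad():
        out = model(
            input_ids=torch.randint(1, 1000, (batch, seq)),
            attention_mask=torch.ones(batch, seq, dtype=torch.long),
            labels=torch.tensor([0, 4]),
            word_pos_ids=torch.zeros(batch, seq, dtype=torch.long),
            word_grammar_ids=torch.zeros(batch, seq, 3, dtype=torch.long),
            word_durations=torch.zeros(batch, seq),
        )
    assert out["logits"].shape == (batch, 9) and out["loss"] is not None
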
# Stable dataset class
class StableAphasiaDataset(Dataset):
    """Stable dataset with simplified processing"""

    def __init__(self, sentences, tokenizer, aphasia_types_mapping, config: ModelConfig):
        self.samples = []
        self.tokenizer = tokenizer
        self.config = config
        self.aphasia_types_mapping = aphasia_types_mapping

        # Add special tokens
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        for idx, item in enumerate(sentences):
            sentence_id = item.get("sentence_id", f"S{idx}")
            aphasia_type = normalize_type(item.get("aphasia_type", ""))
            if aphasia_type not in aphasia_types_mapping:
                log_message(f"Skipping Sentence {sentence_id}: Invalid aphasia type '{aphasia_type}'")
                continue
            self._process_sentence(item, sentence_id, aphasia_type)

        if not self.samples:
            raise ValueError("No valid samples found in dataset!")

        log_message(f"Dataset created with {len(self.samples)} samples")
        self._print_class_distribution()

    def _process_sentence(self, item, sentence_id, aphasia_type):
        """Process sentence with stable approach"""
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        for dialogue_idx, dialogue in enumerate(item.get("dialogues", [])):
            if dialogue_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return

        # Create sample
        self._create_sample(all_tokens, all_pos, all_grammar, all_durations,
                            sentence_id, aphasia_type)

    def _create_sample(self, tokens, pos_ids, grammar_ids, durations,
                       sentence_id, aphasia_type):
        """Create training sample"""
        # Tokenize
        text = " ".join(tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Align features
        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            tokens, pos_ids, grammar_ids, durations, encoded
        )

        # Create prosody features (one vector, repeated for every position)
        prosody_features = self._extract_prosodic_features(durations, tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        label = self.aphasia_types_mapping[aphasia_type]

        sample = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_id
        }
        self.samples.append(sample)

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Align features with BERT subtokens"""
        subtoken_to_token = []
        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        aligned_pos = [0]              # [CLS]
        aligned_grammar = [[0, 0, 0]]  # [CLS]
        aligned_durations = [0.0]      # [CLS]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx]
                                       if original_idx < len(grammar_ids) else [0, 0, 0])

                raw = durations[original_idx] if original_idx < len(durations) else 0.0
                # Durations may arrive as [start, end] pairs; check the length
                # before indexing (the original read raw[1] before its length
                # check, which raised IndexError on one-element lists).
                if isinstance(raw, list):
                    if len(raw) >= 2 and all(isinstance(v, (int, float)) for v in raw[:2]):
                        duration_val = float(raw[1]) - float(raw[0])
                    elif len(raw) == 1 and isinstance(raw[0], (int, float)):
                        duration_val = float(raw[0])
                    else:
                        duration_val = 0.0
                else:
                    duration_val = 0.0
                aligned_durations.append(duration_val)
            else:
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        aligned_pos.append(0)              # [SEP]
        aligned_grammar.append([0, 0, 0])  # [SEP]
        aligned_durations.append(0.0)      # [SEP]

        return aligned_pos, aligned_grammar, aligned_durations
features""" if not durations: return [0.0] * self.config.prosody_dim valid_durations = [d for d in durations if isinstance(d, (int, float)) and d > 0] if not valid_durations: return [0.0] * self.config.prosody_dim features = [ np.mean(valid_durations), np.std(valid_durations), np.median(valid_durations), len([d for d in valid_durations if d > np.mean(valid_durations) * 1.5]) ] # Pad to prosody_dim while len(features) < self.config.prosody_dim: features.append(0.0) return features[:self.config.prosody_dim] def _print_class_distribution(self): """Print class distribution""" label_counts = Counter(sample["labels"].item() for sample in self.samples) reverse_mapping = {v: k for k, v in self.aphasia_types_mapping.items()} log_message("\nClass Distribution:") for label_id, count in sorted(label_counts.items()): class_name = reverse_mapping.get(label_id, f"Unknown_{label_id}") log_message(f" {class_name}: {count} samples") def __len__(self): return len(self.samples) def __getitem__(self, idx): return self.samples[idx] # Stable data collator def stable_collate_fn(batch): """Stable data collation""" if not batch or batch[0] is None: return None try: max_length = batch[0]["input_ids"].size(0) collated_batch = { "input_ids": torch.stack([item["input_ids"] for item in batch]), "attention_mask": torch.stack([item["attention_mask"] for item in batch]), "labels": torch.stack([item["labels"] for item in batch]), "sentence_ids": [item.get("sentence_id", "N/A") for item in batch], "word_pos_ids": torch.stack([item.get("word_pos_ids", torch.zeros(max_length, dtype=torch.long)) for item in batch]), "word_grammar_ids": torch.stack([item.get("word_grammar_ids", torch.zeros(max_length, 3, dtype=torch.long)) for item in batch]), "word_durations": torch.stack([item.get("word_durations", torch.zeros(max_length, dtype=torch.float)) for item in batch]), "prosody_features": torch.stack([item.get("prosody_features", torch.zeros(max_length, 32, dtype=torch.float)) for item in batch]) } return collated_batch except Exception as e: log_message(f"Collation error: {e}") return None # Enhanced Training callback with adaptive learning rate class AdaptiveTrainingCallback(TrainerCallback): """Enhanced training callback with adaptive learning rate and comprehensive tracking""" def __init__(self, config: ModelConfig, patience=5, min_delta=0.8): self.config = config self.patience = patience self.min_delta = min_delta self.best_metric = float('-inf') self.patience_counter = 0 # Learning rate scheduler self.lr_scheduler = None # History tracker self.history_tracker = TrainingHistoryTracker() # Metrics for current epoch self.current_train_metrics = {} self.current_eval_metrics = {} def on_train_begin(self, args, state, control, **kwargs): """Initialize learning rate scheduler""" if self.config.adaptive_lr: model = kwargs.get('model') optimizer = kwargs.get('optimizer') if optimizer and model: total_steps = state.max_steps if state.max_steps > 0 else len(kwargs.get('train_dataloader', [])) * args.num_train_epochs self.lr_scheduler = AdaptiveLearningRateScheduler(optimizer, self.config, total_steps) log_message("Adaptive learning rate scheduler initialized") def on_log(self, args, state, control, logs=None, **kwargs): """Capture training metrics""" if logs: # Store training metrics if 'train_loss' in logs: self.current_train_metrics['loss'] = logs['train_loss'] if 'learning_rate' in logs: self.current_train_metrics['lr'] = logs['learning_rate'] def on_evaluate(self, args, state, control, logs=None, **kwargs): """Handle evaluation and 
learning rate adjustment""" if logs is not None: current_metric = logs.get('eval_f1', 0) current_loss = logs.get('eval_loss', float('inf')) current_acc = logs.get('eval_accuracy', 0) # Store evaluation metrics self.current_eval_metrics = { 'loss': current_loss, 'f1': current_metric, 'accuracy': current_acc, 'precision': logs.get('eval_precision_macro', 0), 'recall': logs.get('eval_recall_macro', 0) } # Update history epoch_metrics = { 'train_loss': self.current_train_metrics.get('loss', 0), 'eval_loss': current_loss, 'train_accuracy': 0, # Will be computed separately if needed 'eval_accuracy': current_acc, 'train_f1': 0, # Will be computed separately if needed 'eval_f1': current_metric, 'learning_rate': self.current_train_metrics.get('lr', self.config.learning_rate), 'train_precision': 0, 'eval_precision': logs.get('eval_precision_macro', 0), 'train_recall': 0, 'eval_recall': logs.get('eval_recall_macro', 0) } self.history_tracker.update(state.epoch, epoch_metrics) # Adaptive learning rate adjustment if self.lr_scheduler and self.config.adaptive_lr: new_lr = self.lr_scheduler.adaptive_lr_calculation(current_loss, current_metric, current_acc) if current_acc > 0.84: log_message(f"Target accuracy reached ({current_acc:.2%}) → stopping and saving model") control.should_save = True control.should_training_stop = True return control # Early stopping logic if current_metric > self.best_metric + self.min_delta: self.best_metric = current_metric self.patience_counter = 0 log_message(f"New best F1 score: {current_metric:.4f}") else: self.patience_counter += 1 log_message(f"No improvement for {self.patience_counter} evaluations") if self.patience_counter >= self.patience: log_message("Early stopping triggered") control.should_training_stop = True clear_memory() def on_train_end(self, args, state, control, **kwargs): """Save training history at the end""" output_dir = args.output_dir self.history_tracker.save_history(output_dir) self.history_tracker.plot_training_curves(output_dir) log_message("Training history and curves saved") # Metrics computation def compute_comprehensive_metrics(pred): """Compute comprehensive evaluation metrics""" predictions = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions labels = pred.label_ids preds = np.argmax(predictions, axis=1) acc = accuracy_score(labels, preds) f1_macro = f1_score(labels, preds, average='macro', zero_division=0) f1_weighted = f1_score(labels, preds, average='weighted', zero_division=0) precision_macro = precision_score(labels, preds, average='macro', zero_division=0) recall_macro = recall_score(labels, preds, average='macro', zero_division=0) # Per-class metrics f1_per_class = f1_score(labels, preds, average=None, zero_division=0) precision_per_class = precision_score(labels, preds, average=None, zero_division=0) recall_per_class = recall_score(labels, preds, average=None, zero_division=0) return { "accuracy": acc, "f1": f1_weighted, "f1_macro": f1_macro, "precision_macro": precision_macro, "recall_macro": recall_macro, "f1_std": np.std(f1_per_class), "precision_std": np.std(precision_per_class), "recall_std": np.std(recall_per_class) } # Enhanced analysis and visualization def generate_comprehensive_reports(trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir): """Generate comprehensive analysis reports and visualizations""" log_message("Generating comprehensive reports...") model = trainer.model if hasattr(model, 'module'): model = model.module model.eval() device = next(model.parameters()).device 
# Enhanced analysis and visualization
def generate_comprehensive_reports(trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir):
    """Generate comprehensive analysis reports and visualizations"""
    log_message("Generating comprehensive reports...")

    model = trainer.model
    if hasattr(model, 'module'):
        model = model.module
    model.eval()
    device = next(model.parameters()).device

    predictions = []
    true_labels = []
    sentence_ids = []
    severity_preds = []
    fluency_preds = []
    prediction_probs = []

    # Evaluation
    dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=stable_collate_fn)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue

            # Move to device
            for key in ['input_ids', 'attention_mask', 'word_pos_ids',
                        'word_grammar_ids', 'word_durations', 'labels', 'prosody_features']:
                if key in batch:
                    batch[key] = batch[key].to(device)

            outputs = model(**batch)
            logits = outputs["logits"]
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(batch["labels"].cpu().numpy())
            sentence_ids.extend(batch["sentence_ids"])
            severity_preds.extend(outputs["severity_pred"].cpu().numpy())
            fluency_preds.extend(outputs["fluency_pred"].cpu().numpy())
            prediction_probs.extend(probs.cpu().numpy())

    # Analysis
    reverse_mapping = {v: k for k, v in aphasia_types_mapping.items()}

    # 1. Detailed predictions
    log_message("=== DETAILED PREDICTIONS (First 20) ===")
    for i in range(min(20, len(predictions))):
        true_type = reverse_mapping.get(true_labels[i], 'Unknown')
        pred_type = reverse_mapping.get(predictions[i], 'Unknown')
        severity_level = np.argmax(severity_preds[i])
        fluency_score = fluency_preds[i][0] if isinstance(fluency_preds[i], np.ndarray) else fluency_preds[i]
        confidence = np.max(prediction_probs[i])
        log_message(f"ID: {sentence_ids[i]} | True: {true_type} | Pred: {pred_type} | "
                    f"Confidence: {confidence:.3f} | Severity: {severity_level} | Fluency: {fluency_score:.3f}")

    # 2. Confusion matrix
    # Pass the full label set so the matrix keeps all 9 rows/columns even when
    # a class is absent from the evaluation split; otherwise the tick labels
    # below would not line up with the matrix.
    cm = confusion_matrix(true_labels, predictions, labels=list(aphasia_types_mapping.values()))

    # Enhanced confusion matrix plot
    plt.figure(figsize=(14, 12))

    # Calculate percentages (guard against all-zero rows for absent classes)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_percentage = cm.astype('float') / np.maximum(row_sums, 1) * 100

    # Create annotation array
    annotations = np.empty_like(cm, dtype=object)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            annotations[i, j] = f'{cm[i, j]}\n({cm_percentage[i, j]:.1f}%)'

    sns.heatmap(cm, annot=annotations, fmt='', cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()),
                cbar_kws={'label': 'Count'})
    plt.xlabel("Predicted Label", fontsize=12, fontweight='bold')
    plt.ylabel("True Label", fontsize=12, fontweight='bold')
    plt.title("Enhanced Confusion Matrix\n(Count and Percentage)", fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "enhanced_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Classification report
    all_label_ids = list(aphasia_types_mapping.values())
    report_dict = classification_report(
        true_labels, predictions,
        labels=all_label_ids,
        target_names=list(aphasia_types_mapping.keys()),
        output_dict=True,
        zero_division=0
    )
    df_report = pd.DataFrame(report_dict).transpose()
    df_report.to_csv(os.path.join(output_dir, "comprehensive_classification_report.csv"))

    # 4. Per-class performance visualization
    class_names = list(aphasia_types_mapping.keys())
    metrics_data = []
    for i, class_name in enumerate(class_names):
        if class_name in report_dict:
            metrics_data.append({
                'Class': class_name,
                'Precision': report_dict[class_name]['precision'],
                'Recall': report_dict[class_name]['recall'],
                'F1-Score': report_dict[class_name]['f1-score'],
                'Support': report_dict[class_name]['support']
            })

    df_metrics = pd.DataFrame(metrics_data)
    df_metrics.to_csv(os.path.join(output_dir, "per_class_metrics.csv"), index=False)

    # Plot per-class performance
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Precision
    axes[0, 0].bar(df_metrics['Class'], df_metrics['Precision'], color='skyblue', alpha=0.8)
    axes[0, 0].set_title('Precision by Class', fontweight='bold')
    axes[0, 0].set_ylabel('Precision')
    axes[0, 0].tick_params(axis='x', rotation=45)
    axes[0, 0].grid(True, alpha=0.3)

    # Recall
    axes[0, 1].bar(df_metrics['Class'], df_metrics['Recall'], color='lightcoral', alpha=0.8)
    axes[0, 1].set_title('Recall by Class', fontweight='bold')
    axes[0, 1].set_ylabel('Recall')
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].grid(True, alpha=0.3)

    # F1-Score
    axes[1, 0].bar(df_metrics['Class'], df_metrics['F1-Score'], color='lightgreen', alpha=0.8)
    axes[1, 0].set_title('F1-Score by Class', fontweight='bold')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].tick_params(axis='x', rotation=45)
    axes[1, 0].grid(True, alpha=0.3)

    # Support
    axes[1, 1].bar(df_metrics['Class'], df_metrics['Support'], color='gold', alpha=0.8)
    axes[1, 1].set_title('Support by Class', fontweight='bold')
    axes[1, 1].set_ylabel('Support (Number of Samples)')
    axes[1, 1].tick_params(axis='x', rotation=45)
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "per_class_performance.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 5. Prediction confidence distribution
    confidences = [np.max(prob) for prob in prediction_probs]
    correct_predictions = [pred == true for pred, true in zip(predictions, true_labels)]

    plt.figure(figsize=(12, 8))

    # Separate correct and incorrect predictions
    correct_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if correct]
    incorrect_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if not correct]

    plt.hist(correct_confidences, bins=30, alpha=0.7, label='Correct Predictions',
             color='green', density=True)
    plt.hist(incorrect_confidences, bins=30, alpha=0.7, label='Incorrect Predictions',
             color='red', density=True)
    plt.xlabel('Prediction Confidence', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Prediction Confidence', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confidence_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 6. Feature analysis
    log_message("=== FEATURE ANALYSIS ===")
    avg_severity = np.mean(severity_preds, axis=0)
    avg_fluency = np.mean(fluency_preds)
    std_fluency = np.std(fluency_preds)

    log_message(f"Average Severity Distribution: {avg_severity}")
    log_message(f"Average Fluency Score: {avg_fluency:.3f} ± {std_fluency:.3f}")

    # 7. Save detailed results
    results_df = pd.DataFrame({
        'sentence_id': sentence_ids,
        'true_label': [reverse_mapping[label] for label in true_labels],
        'predicted_label': [reverse_mapping[pred] for pred in predictions],
        'prediction_confidence': confidences,
        'correct_prediction': correct_predictions,
        'severity_level': [np.argmax(severity) for severity in severity_preds],
        'fluency_score': [fluency[0] if isinstance(fluency, np.ndarray) else fluency
                          for fluency in fluency_preds]
    })

    # Add probability columns for each class
    for i, class_name in enumerate(aphasia_types_mapping.keys()):
        results_df[f'prob_{class_name}'] = [prob[i] for prob in prediction_probs]

    results_df.to_csv(os.path.join(output_dir, "comprehensive_results.csv"), index=False)

    # 8. Summary statistics
    summary_stats = {
        'Overall Accuracy': accuracy_score(true_labels, predictions),
        'Macro F1': f1_score(true_labels, predictions, average='macro', zero_division=0),
        'Weighted F1': f1_score(true_labels, predictions, average='weighted', zero_division=0),
        'Macro Precision': precision_score(true_labels, predictions, average='macro', zero_division=0),
        'Macro Recall': recall_score(true_labels, predictions, average='macro', zero_division=0),
        'Average Confidence': np.mean(confidences),
        'Confidence Std': np.std(confidences),
        'Average Severity': avg_severity.tolist(),
        'Average Fluency': avg_fluency,
        'Fluency Std': std_fluency
    }

    serializable_summary = {
        k: float(v) if isinstance(v, (np.floating, np.integer)) else v
        for k, v in summary_stats.items()
    }
    with open(os.path.join(output_dir, "summary_statistics.json"), "w") as f:
        json.dump(serializable_summary, f, indent=2)

    log_message("Comprehensive Classification Report:")
    log_message(df_report.to_string())
    log_message(f"Comprehensive results saved to {output_dir}")

    return results_df, df_report, summary_stats
# Main training function with adaptive learning rate
def train_adaptive_model(json_file: str, output_dir: str = "./adaptive_aphasia_model"):
    """Main training function with adaptive learning rate"""
    log_message("Starting Adaptive Aphasia Classification Training")
    log_message("=" * 60)

    # Setup
    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_message(f"Using device: {device}")

    # Load data
    log_message("Loading dataset...")
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)
    sentences = dataset_json.get("sentences", [])

    # Normalize aphasia types
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    # Aphasia types mapping
    aphasia_types_mapping = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
    }
    log_message(f"Aphasia Types Mapping: {aphasia_types_mapping}")

    num_labels = len(aphasia_types_mapping)
    log_message(f"Number of labels: {num_labels}")

    # Filter sentences
    filtered_sentences = []
    for item in sentences:
        aphasia_type = item.get("aphasia_type", "")
        if aphasia_type in aphasia_types_mapping:
            filtered_sentences.append(item)
        else:
            log_message(f"Excluding sentence with invalid type: {aphasia_type}")

    log_message(f"Filtered dataset: {len(filtered_sentences)} sentences")

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create dataset
    random.shuffle(filtered_sentences)
    dataset_all = StableAphasiaDataset(filtered_sentences, tokenizer, aphasia_types_mapping, config)

    # Split dataset
    total_samples = len(dataset_all)
    train_size = int(0.8 * total_samples)
    eval_size = total_samples - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset_all, [train_size, eval_size]
    )
    log_message(f"Train size: {train_size}, Eval size: {eval_size}")

    # Setup weighted sampling for class imbalance
    # NOTE: the sampler is built here but is not wired into the Trainer, which
    # uses its own (shuffling) sampler; kept for reference.
    train_labels = [dataset_all.samples[idx]["labels"].item() for idx in train_dataset.indices]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    # Model initialization
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_strategy="steps",
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=3,
        remove_unused_columns=False,
    )

    # Initialize trainer with adaptive callback
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.001)]
    )

    # Start training
    log_message("Starting adaptive training...")
    try:
        trainer.train()
        log_message("Training completed successfully!")
    except Exception as e:
        log_message(f"Training error: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
        raise

    # Final evaluation
    log_message("Starting final evaluation...")
    eval_results = trainer.evaluate()
    log_message(f"Final evaluation results: {eval_results}")

    # Generate comprehensive reports
    results_df, report_df, summary_stats = generate_comprehensive_reports(
        trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir
    )

    # Save model
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(output_dir)

    # Save configuration
    config_dict = {
        "model_name": config.model_name,
        "num_labels": num_labels,
        "aphasia_types_mapping": aphasia_types_mapping,
        "training_args": training_args.to_dict(),
        "adaptive_lr_config": {
            "adaptive_lr": config.adaptive_lr,
            "lr_patience": config.lr_patience,
            "lr_factor": config.lr_factor,
            "lr_increase_factor": config.lr_increase_factor,
            "min_lr": config.min_lr,
            "max_lr": config.max_lr,
            "oscillation_amplitude": config.oscillation_amplitude
        }
    }
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    log_message(f"Adaptive model and comprehensive reports saved to {output_dir}")
    clear_memory()

    return trainer, eval_results, results_df

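# Illustrative reload of a trained checkpoint for inference (hypothetical
# helper; assumes pytorch_model.bin, the tokenizer files, and config.json were
# written by train_adaptive_model into `model_dir`).
def _demo_load_for_inference(model_dir="./adaptive_aphasia_model"):
    with open(os.path.join(model_dir, "config.json")) as f:
        saved = json.load(f)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = StableAphasiaClassifier(ModelConfig(), saved["num_labels"])
    # Resize before loading so the embedding matrix matches the saved weights,
    # which include the added special tokens.
    model.bert.resize_token_embeddings(len(tokenizer))
    state = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
    model.load_state_dict(state)
    model.eval()
    return model, tokenizer, saved["aphasia_types_mapping"]
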
# Cross-validation with adaptive learning rate
def train_adaptive_cross_validation(json_file: str, output_dir: str = "./adaptive_cv_results",
                                    n_folds: int = 5):
    """Cross-validation training with adaptive learning rate"""
    log_message("Starting Adaptive Cross-Validation Training")

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    # Load and prepare data
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)
    sentences = dataset_json.get("sentences", [])

    # Normalize and filter
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    aphasia_types_mapping = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
    }
    filtered_sentences = [s for s in sentences if s.get("aphasia_type") in aphasia_types_mapping]

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Create full dataset
    full_dataset = StableAphasiaDataset(filtered_sentences, tokenizer, aphasia_types_mapping, config)

    # Extract labels for stratification
    sample_labels = [sample["labels"].item() for sample in full_dataset.samples]

    # Cross-validation
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []
    all_predictions = []
    all_true_labels = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(sample_labels)), sample_labels)):
        log_message(f"\n=== Fold {fold + 1}/{n_folds} ===")

        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        # Train single fold
        fold_trainer, fold_results_dict, fold_predictions = train_adaptive_single_fold(
            train_subset, val_subset, config, aphasia_types_mapping,
            tokenizer, fold, output_dir
        )

        # trainer.evaluate() prefixes every metric with "eval_"; strip the
        # prefix so the summary columns below ('accuracy', 'f1', ...) resolve.
        fold_metrics = {k[len("eval_"):] if k.startswith("eval_") else k: v
                        for k, v in fold_results_dict.items()}
        fold_results.append({'fold': fold + 1, **fold_metrics})

        # Collect predictions for ensemble analysis
        all_predictions.extend(fold_predictions['predictions'])
        all_true_labels.extend(fold_predictions['true_labels'])

        clear_memory()

    # Aggregate results
    results_df = pd.DataFrame(fold_results)
    results_df.to_csv(os.path.join(output_dir, "adaptive_cv_summary.csv"), index=False)

    # Cross-validation summary statistics (cast to float so json.dump accepts
    # the numpy scalars returned by pandas)
    cv_summary = {
        'mean_accuracy': float(results_df['accuracy'].mean()),
        'std_accuracy': float(results_df['accuracy'].std()),
        'mean_f1': float(results_df['f1'].mean()),
        'std_f1': float(results_df['f1'].std()),
        'mean_f1_macro': float(results_df['f1_macro'].mean()),
        'std_f1_macro': float(results_df['f1_macro'].std()),
        'mean_precision': float(results_df['precision_macro'].mean()),
        'std_precision': float(results_df['precision_macro'].std()),
        'mean_recall': float(results_df['recall_macro'].mean()),
        'std_recall': float(results_df['recall_macro'].std())
    }
    with open(os.path.join(output_dir, "cv_statistics.json"), "w") as f:
        json.dump(cv_summary, f, indent=2)

    # Overall confusion matrix across all folds
    overall_cm = confusion_matrix(all_true_labels, all_predictions,
                                  labels=list(aphasia_types_mapping.values()))
    plt.figure(figsize=(12, 10))
    sns.heatmap(overall_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()))
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Overall Confusion Matrix (All Folds)")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "overall_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Cross-validation results visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Accuracy across folds
    axes[0, 0].bar(range(1, n_folds + 1), results_df['accuracy'], color='skyblue', alpha=0.8)
    axes[0, 0].axhline(y=results_df['accuracy'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["accuracy"].mean():.3f}')
    axes[0, 0].set_title('Accuracy Across Folds')
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    # F1 Score across folds
    axes[0, 1].bar(range(1, n_folds + 1), results_df['f1'], color='lightgreen', alpha=0.8)
    axes[0, 1].axhline(y=results_df['f1'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["f1"].mean():.3f}')
    axes[0, 1].set_title('F1 Score Across Folds')
    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('F1 Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Precision across folds
    axes[1, 0].bar(range(1, n_folds + 1), results_df['precision_macro'], color='coral', alpha=0.8)
    axes[1, 0].axhline(y=results_df['precision_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["precision_macro"].mean():.3f}')
    axes[1, 0].set_title('Precision Across Folds')
    axes[1, 0].set_xlabel('Fold')
    axes[1, 0].set_ylabel('Precision')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Recall across folds
    axes[1, 1].bar(range(1, n_folds + 1), results_df['recall_macro'], color='gold', alpha=0.8)
    axes[1, 1].axhline(y=results_df['recall_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["recall_macro"].mean():.3f}')
    axes[1, 1].set_title('Recall Across Folds')
    axes[1, 1].set_xlabel('Fold')
    axes[1, 1].set_ylabel('Recall')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "cv_performance_comparison.png"), dpi=300, bbox_inches='tight')
    plt.close()

    log_message("\n=== Adaptive Cross-Validation Summary ===")
    log_message(results_df.to_string(index=False))

    # Statistics
    log_message(f"\nMean F1: {results_df['f1'].mean():.4f} ± {results_df['f1'].std():.4f}")
    log_message(f"Mean Accuracy: {results_df['accuracy'].mean():.4f} ± {results_df['accuracy'].std():.4f}")
    log_message(f"Mean F1 Macro: {results_df['f1_macro'].mean():.4f} ± {results_df['f1_macro'].std():.4f}")

    return results_df, cv_summary

def train_adaptive_single_fold(train_dataset, val_dataset, config, aphasia_types_mapping,
                               tokenizer, fold, output_dir):
    """Train a single fold with adaptive learning rate"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_labels = len(aphasia_types_mapping)

    # Setup weighted sampling
    # NOTE: as in train_adaptive_model, the sampler is constructed but not
    # passed to the Trainer; kept for reference.
    train_labels = [train_dataset[i]["labels"].item() for i in range(len(train_dataset))]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights),
                                    replacement=True)

    # Model initialization
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    # Training arguments
    fold_output_dir = os.path.join(output_dir, f"fold_{fold}")
    os.makedirs(fold_output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=fold_output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=1,
        remove_unused_columns=False,
    )
    # Trainer with adaptive callback
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.001)]
    )

    # Train
    trainer.train()

    # Evaluate
    eval_results = trainer.evaluate()

    # Get predictions for ensemble analysis
    predictions = trainer.predict(val_dataset)
    pred_labels = np.argmax(
        predictions.predictions[0] if isinstance(predictions.predictions, tuple)
        else predictions.predictions,
        axis=1
    )
    true_labels = predictions.label_ids

    fold_predictions = {
        'predictions': pred_labels.tolist(),
        'true_labels': true_labels.tolist()
    }

    # Save fold model
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module
    torch.save(model_to_save.state_dict(), os.path.join(fold_output_dir, "pytorch_model.bin"))

    return trainer, eval_results, fold_predictions

# Main execution
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Adaptive Learning Rate Aphasia Classification Training")
    parser.add_argument("--output_dir", type=str, default="./adaptive_aphasia_model",
                        help="Output directory")
    parser.add_argument("--cross_validation", action="store_true", help="Use cross-validation")
    parser.add_argument("--n_folds", type=int, default=5, help="Number of CV folds")
    parser.add_argument("--json_file", type=str, default=json_file, help="Path to JSON dataset file")
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Initial learning rate")
    parser.add_argument("--batch_size", type=int, default=24, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--adaptive_lr", action="store_true", default=True,
                        help="Use adaptive learning rate")
    args = parser.parse_args()

    # Update config with command line arguments.
    # NOTE: train_adaptive_model and train_adaptive_cross_validation construct
    # their own ModelConfig internally, so these overrides only affect the
    # logging below unless they are threaded through to those functions.
    config = ModelConfig()
    config.learning_rate = args.learning_rate
    config.batch_size = args.batch_size
    config.num_epochs = args.num_epochs
    config.adaptive_lr = args.adaptive_lr

    try:
        clear_memory()
        log_message(f"Starting training with adaptive_lr={config.adaptive_lr}")
        log_message(f"Config: lr={config.learning_rate}, batch_size={config.batch_size}, "
                    f"epochs={config.num_epochs}")

        if args.cross_validation:
            results_df, cv_summary = train_adaptive_cross_validation(
                args.json_file, args.output_dir, args.n_folds
            )
            log_message("Cross-validation training completed!")
        else:
            trainer, eval_results, results_df = train_adaptive_model(args.json_file, args.output_dir)
            log_message("Single model training completed!")

        log_message("All adaptive training completed successfully!")

    except Exception as e:
        log_message(f"Training failed: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
    finally:
        clear_memory()