|
|
|
|
|
"""
|
|
|
Advanced Multi-Modal Aphasia Classification System
|
|
|
With Adaptive Learning Rate and Comprehensive Reporting
|
|
|
"""
|
|
|
|
|
|
import re
|
|
|
import json
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import time
|
|
|
import datetime
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import random
|
|
|
import csv
|
|
|
import math
|
|
|
from collections import Counter, defaultdict
|
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
import torch.optim as optim
|
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
|
|
|
from transformers import (
|
|
|
AutoTokenizer, AutoModel, AutoConfig,
|
|
|
TrainingArguments, Trainer, TrainerCallback,
|
|
|
EarlyStoppingCallback, get_cosine_schedule_with_warmup,
|
|
|
default_data_collator, set_seed
|
|
|
)
|
|
|
|
|
|
import seaborn as sns
|
|
|
import matplotlib.pyplot as plt
|
|
|
import pandas as pd
|
|
|
from sklearn.metrics import (
|
|
|
accuracy_score, f1_score, precision_score, recall_score,
|
|
|
confusion_matrix, classification_report, roc_auc_score
|
|
|
)
|
|
|
from sklearn.model_selection import StratifiedKFold
|
|
|
import gc
|
|
|
from scipy import stats
|
|
|
|
|
|
|
|
|
# Synchronous CUDA kernel launches and device-side assertions: slower, but
# CUDA errors surface at the offending call instead of asynchronously later.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
# Silence HF tokenizers fork-parallelism warnings when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Path to the augmented aphasia dataset consumed by this script.
json_file = '/workspace/SH001/aphasia_data_augmented.json'
|
|
|
|
|
|
|
|
|
def set_all_seeds(seed=42):
    """Seed every RNG this pipeline touches (python, numpy, torch CPU/CUDA, hash)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)


set_all_seeds(42)
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """All hyperparameters for the multi-modal aphasia classifier and its training loop."""

    # --- BERT backbone ---
    model_name: str = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
    max_length: int = 512  # tokenizer truncation/padding length
    hidden_size: int = 768

    # --- Linguistic feature streams ---
    pos_vocab_size: int = 150      # size of the POS-tag embedding table
    pos_emb_dim: int = 64
    grammar_dim: int = 3           # per-word grammar feature triple
    grammar_hidden_dim: int = 64
    duration_hidden_dim: int = 128
    prosody_dim: int = 32          # utterance-level prosodic feature vector size

    # --- Attention ---
    num_attention_heads: int = 8
    attention_dropout: float = 0.3

    # --- Classifier head ---
    # None means "use the default"; replaced with [512, 256] in __post_init__.
    classifier_hidden_dims: Optional[List[int]] = None
    dropout_rate: float = 0.3
    activation_fn: str = "tanh"

    # --- Optimization ---
    learning_rate: float = 5e-4
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    batch_size: int = 10
    num_epochs: int = 500
    gradient_accumulation_steps: int = 4

    # --- Adaptive learning-rate scheduler (see AdaptiveLearningRateScheduler) ---
    adaptive_lr: bool = True
    lr_patience: int = 3
    lr_factor: float = 0.8
    lr_increase_factor: float = 1.2
    min_lr: float = 1e-6
    max_lr: float = 1e-3
    oscillation_amplitude: float = 0.1  # amplitude of the sinusoidal LR wobble

    # --- Loss options ---
    # NOTE: when use_focal_loss is True, label smoothing is ignored
    # (see StableAphasiaClassifier._compute_loss).
    use_focal_loss: bool = True
    focal_alpha: float = 1.0
    focal_gamma: float = 2.0
    use_mixup: bool = False
    mixup_alpha: float = 0.2
    use_label_smoothing: bool = True
    label_smoothing: float = 0.1

    def __post_init__(self):
        """Fill in the mutable classifier-layer default (avoids a shared mutable default)."""
        if self.classifier_hidden_dims is None:
            self.classifier_hidden_dims = [512, 256]
|
|
|
|
|
|
|
|
|
def log_message(message):
    """Append a timestamped message to ./training_log.txt and echo it to stdout."""
    stamped = f"{datetime.datetime.now().isoformat()}: {message}"
    with open("./training_log.txt", "a", encoding="utf-8") as handle:
        handle.write(stamped + "\n")
    print(stamped, flush=True)
|
|
|
|
|
|
def clear_memory():
    """Run the garbage collector and, when CUDA is present, drop cached GPU allocations."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
|
|
def normalize_type(t):
    """Strip and upper-case a label string; non-string values pass through unchanged."""
    if not isinstance(t, str):
        return t
    return t.strip().upper()
|
|
|
|
|
|
|
|
|
class AdaptiveLearningRateScheduler:
    """Adaptive learning-rate scheduler combining several strategies.

    (docstring translated from Chinese)
    Blends plateau detection, loss/F1-level heuristics, a sinusoidal
    oscillation, and a cosine decay envelope. The resulting LR is clamped to
    [config.min_lr, config.max_lr] and written directly into the optimizer's
    param groups by adaptive_lr_calculation().
    """

    def __init__(self, optimizer, config: ModelConfig, total_steps: int):
        self.optimizer = optimizer
        self.config = config
        self.total_steps = total_steps

        # Metric traces used for slope/plateau detection.
        self.loss_history = []
        self.f1_history = []
        self.accuracy_history = []
        self.lr_history = []

        # NOTE(review): plateau_counter / best_f1 / best_loss are written but
        # never read anywhere in this class — likely leftovers; confirm before
        # relying on them.
        self.plateau_counter = 0
        self.best_f1 = 0.0
        self.best_loss = float('inf')
        self.step_count = 0

        # LR state; current_lr is the last value applied to the optimizer.
        self.base_lr = config.learning_rate
        self.current_lr = self.base_lr

        log_message(f"Adaptive LR Scheduler initialized with base_lr={self.base_lr}")

    def calculate_slope(self, values, window=3):
        """Least-squares slope over the last `window` values (0.0 if too few points)."""
        if len(values) < window:
            return 0.0

        recent_values = values[-window:]
        x = np.arange(len(recent_values))
        # scipy.stats.linregress returns (slope, intercept, r, p, stderr).
        slope, _, _, _, _ = stats.linregress(x, recent_values)
        return slope

    def exponential_adjustment(self, current_value, target_value, base_factor=1.1):
        """Exponential adjustment factor: base_factor * exp(-current/target)."""
        # Guard against division by zero; ratio 1.0 is the neutral fallback.
        ratio = current_value / target_value if target_value != 0 else 1.0
        factor = math.exp(-ratio) * base_factor
        return factor

    def logarithmic_adjustment(self, current_value, threshold=0.1):
        """Logarithmic adjustment factor log(1 + value/threshold), clipped to [0.5, 2.0]."""
        if current_value <= 0:
            return 1.0
        factor = math.log(1 + current_value / threshold)
        return max(0.5, min(2.0, factor))

    def sinusoidal_oscillation(self, step, amplitude=None):
        """Sinusoidal multiplier around 1.0; four full periods across total_steps."""
        if amplitude is None:
            amplitude = self.config.oscillation_amplitude

        phase = 2 * math.pi * step / (self.total_steps / 4)
        oscillation = 1 + amplitude * math.sin(phase)
        return oscillation

    def cosine_decay(self, step):
        """Cosine envelope: 1.0 at step 0 decaying to 0.0 at total_steps."""
        progress = step / self.total_steps
        decay = 0.5 * (1 + math.cos(math.pi * progress))
        return decay

    def adaptive_lr_calculation(self, current_loss, current_f1, current_acc):
        """Combine all heuristics into a new LR, apply it to the optimizer, return it."""
        # Record the latest evaluation metrics.
        self.loss_history.append(current_loss)
        self.f1_history.append(current_f1)
        self.accuracy_history.append(current_acc)

        # Recent trends over the last few evaluations.
        loss_slope = self.calculate_slope(self.loss_history)
        f1_slope = self.calculate_slope(self.f1_history)
        # NOTE(review): acc_slope is computed but never used below.
        acc_slope = self.calculate_slope(self.accuracy_history)

        adjustment_factor = 1.0

        # Loss heuristics: bump LR on a plateau, damp it when loss is high.
        if abs(loss_slope) < 0.001:
            log_message(f"Loss plateau detected (slope: {loss_slope:.6f})")

            exp_factor = self.exponential_adjustment(abs(loss_slope), 0.01, 1.15)
            adjustment_factor *= exp_factor

        elif current_loss > 2.0:
            log_message(f"High loss detected: {current_loss:.4f}")

            log_factor = self.logarithmic_adjustment(current_loss, 1.0)
            adjustment_factor *= log_factor

        # F1 heuristics: push harder while F1 is low or flat.
        if current_f1 < 0.3:
            log_message(f"Low F1 detected: {current_f1:.4f}")

            exp_factor = self.exponential_adjustment(0.3, current_f1, 1.2)
            adjustment_factor *= exp_factor

        elif abs(f1_slope) < 0.001:
            log_message(f"F1 plateau detected (slope: {f1_slope:.6f})")
            adjustment_factor *= 1.1

        # Periodic exploration term.
        sin_factor = self.sinusoidal_oscillation(self.step_count)

        # Long-horizon decay envelope.
        cos_factor = self.cosine_decay(self.step_count)

        # Blend: the (0.3 + 0.7 * cos) envelope keeps a 30% floor of the
        # heuristic LR even at the very end of training.
        final_factor = adjustment_factor * sin_factor * (0.3 + 0.7 * cos_factor)

        new_lr = self.current_lr * final_factor

        # Clamp to the configured LR range.
        new_lr = max(self.config.min_lr, min(self.config.max_lr, new_lr))

        # Only touch the optimizer when the change is non-trivial.
        if abs(new_lr - self.current_lr) > 1e-7:
            self.current_lr = new_lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = new_lr

            log_message(f"Learning rate adjusted: {new_lr:.2e} (factor: {final_factor:.3f})")
            log_message(f" - Loss slope: {loss_slope:.6f}, F1 slope: {f1_slope:.6f}")
            log_message(f" - Sin factor: {sin_factor:.3f}, Cos factor: {cos_factor:.3f}")

        self.lr_history.append(self.current_lr)
        self.step_count += 1

        return self.current_lr
|
|
|
|
|
|
|
|
|
class TrainingHistoryTracker:
    """Record per-epoch training/evaluation metrics and export them as CSV and plots.

    (docstring translated from Chinese: "training history recorder")
    """

    def __init__(self):
        # One list per metric. All lists are kept exactly as long as 'epoch'
        # (see update()) so pd.DataFrame(self.history) can always be built.
        self.history = {
            'epoch': [],
            'train_loss': [],
            'eval_loss': [],
            'train_accuracy': [],
            'eval_accuracy': [],
            'train_f1': [],
            'eval_f1': [],
            'learning_rate': [],
            'train_precision': [],
            'eval_precision': [],
            'train_recall': [],
            'eval_recall': []
        }

    def update(self, epoch, metrics):
        """Append one epoch of metrics.

        Bug fix: the original appended only the keys present in `metrics`,
        which let the per-metric lists drift out of sync with 'epoch' and made
        pd.DataFrame(self.history) raise in save_history(). Missing metrics
        are now recorded as None so every column stays the same length.
        Keys in `metrics` that are not tracked are still ignored.
        """
        self.history['epoch'].append(epoch)
        for key in self.history:
            if key != 'epoch':
                self.history[key].append(metrics.get(key))

    def save_history(self, output_dir):
        """Write the history to <output_dir>/training_history.csv and return the DataFrame."""
        df = pd.DataFrame(self.history)
        df.to_csv(os.path.join(output_dir, "training_history.csv"), index=False)
        return df

    def plot_training_curves(self, output_dir):
        """Render a 2x3 grid of metric curves to <output_dir>/training_curves.png."""
        if not self.history['epoch']:
            return

        # Robustness fix: 'seaborn-v0_8' is not shipped with every matplotlib
        # version — fall back to the default style instead of crashing.
        try:
            plt.style.use('seaborn-v0_8')
        except OSError:
            plt.style.use('default')

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        epochs = self.history['epoch']

        # (axis, train key, eval key, title, y-label, legend base) per panel.
        paired_panels = [
            (axes[0, 0], 'train_loss', 'eval_loss', 'Loss Over Time', 'Loss', 'Loss'),
            (axes[0, 1], 'train_accuracy', 'eval_accuracy', 'Accuracy Over Time', 'Accuracy', 'Accuracy'),
            (axes[0, 2], 'train_f1', 'eval_f1', 'F1 Score Over Time', 'F1 Score', 'F1'),
            (axes[1, 1], 'train_precision', 'eval_precision', 'Precision Over Time', 'Precision', 'Precision'),
            (axes[1, 2], 'train_recall', 'eval_recall', 'Recall Over Time', 'Recall', 'Recall'),
        ]
        for ax, train_key, eval_key, title, ylabel, base in paired_panels:
            self._plot_pair(ax, epochs, train_key, eval_key, title, ylabel, base)

        # Learning-rate panel: single curve on a log scale.
        lr_ax = axes[1, 0]
        lr_ax.plot(epochs, self.history['learning_rate'], 'g-', linewidth=2)
        lr_ax.set_title('Learning Rate Over Time', fontsize=14, fontweight='bold')
        lr_ax.set_xlabel('Epoch')
        lr_ax.set_ylabel('Learning Rate')
        lr_ax.set_yscale('log')
        lr_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "training_curves.png"), dpi=300, bbox_inches='tight')
        plt.close()

    def _plot_pair(self, ax, epochs, train_key, eval_key, title, ylabel, base):
        """Draw one train-vs-eval curve pair on `ax` with shared styling."""
        ax.plot(epochs, self.history[train_key], 'b-', label=f'Train {base}', linewidth=2)
        ax.plot(epochs, self.history[eval_key], 'r-', label=f'Eval {base}', linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by alpha * (1 - p_t)^gamma to down-weight easy examples."""

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits `inputs` against integer class `targets`."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # p_t: model's probability for the true class of each sample.
        prob_true = torch.exp(-per_sample_ce)
        scaled = self.alpha * (1 - prob_true) ** self.gamma * per_sample_ce

        if self.reduction == 'sum':
            return scaled.sum()
        if self.reduction == 'mean':
            return scaled.mean()
        return scaled
|
|
|
|
|
|
|
|
|
class StablePositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding blended with a small learnable residual."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        # Classic sin/cos table, registered as a (1, max_len, d_model) buffer.
        table = torch.zeros(max_len, d_model)
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        freqs = torch.exp(torch.arange(0, d_model, 2).float() *
                          (-math.log(10000.0) / d_model))
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer('pe', table.unsqueeze(0))

        # Small learnable component layered on top of the fixed table.
        self.learnable_pe = nn.Parameter(torch.randn(max_len, d_model) * 0.01)

    def forward(self, x):
        """Add 0.1 * (sinusoidal + learnable) positional signal to `x` (batch, seq, d_model)."""
        length = x.size(1)
        fixed = self.pe[:, :length, :].to(x.device)
        learned = self.learnable_pe[:length, :].unsqueeze(0).expand(x.size(0), -1, -1)
        return x + 0.1 * (fixed + learned)
|
|
|
|
|
|
|
|
|
class StableMultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention with residual + LayerNorm output."""

    def __init__(self, feature_dim: int, num_heads: int = 4, dropout: float = 0.3):
        super().__init__()
        self.num_heads = num_heads
        self.feature_dim = feature_dim
        self.head_dim = feature_dim // num_heads

        # Heads must partition the feature dimension evenly.
        assert feature_dim % num_heads == 0

        self.query = nn.Linear(feature_dim, feature_dim)
        self.key = nn.Linear(feature_dim, feature_dim)
        self.value = nn.Linear(feature_dim, feature_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(feature_dim, feature_dim)
        self.layer_norm = nn.LayerNorm(feature_dim)

    def forward(self, x, mask=None):
        """Self-attend over `x` (batch, seq, feature_dim); zeros in `mask` are ignored positions."""
        bsz, length, _ = x.size()

        def split_heads(projected):
            # (batch, seq, dim) -> (batch, heads, seq, head_dim)
            return projected.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Broadcast a (batch, seq) mask over heads and query positions.
            broadcast_mask = mask.unsqueeze(1).unsqueeze(1) if mask.dim() == 2 else mask
            scores.masked_fill_(broadcast_mask == 0, -1e9)

        weights = self.dropout(F.softmax(scores, dim=-1))

        merged = torch.matmul(weights, v).transpose(1, 2).contiguous().view(
            bsz, length, self.feature_dim)

        # Residual connection followed by LayerNorm.
        return self.layer_norm(self.output_proj(merged) + x)
|
|
|
|
|
|
|
|
|
class StableLinguisticFeatureExtractor(nn.Module):
    """Project POS, grammar, duration and prosody streams and fuse them into one pooled vector."""

    def __init__(self, config: "ModelConfig"):
        super().__init__()
        self.config = config

        # POS tags: embedding followed by light self-attention over the tag sequence.
        self.pos_embedding = nn.Embedding(config.pos_vocab_size, config.pos_emb_dim, padding_idx=0)
        self.pos_attention = StableMultiHeadAttention(config.pos_emb_dim, num_heads=4)

        # Per-word grammar triple -> hidden projection.
        self.grammar_projection = nn.Sequential(
            nn.Linear(config.grammar_dim, config.grammar_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.grammar_hidden_dim),
            nn.Dropout(config.dropout_rate * 0.3),
        )

        # Scalar per-word duration -> hidden projection.
        self.duration_projection = nn.Sequential(
            nn.Linear(1, config.duration_hidden_dim),
            nn.Tanh(),
            nn.LayerNorm(config.duration_hidden_dim),
        )

        # Prosodic feature vector (broadcast per token by the dataset).
        self.prosody_projection = nn.Sequential(
            nn.Linear(config.prosody_dim, config.prosody_dim),
            nn.ReLU(),
            nn.LayerNorm(config.prosody_dim),
        )

        # Concatenated streams -> half-size fused representation.
        fused_in = (config.pos_emb_dim + config.grammar_hidden_dim +
                    config.duration_hidden_dim + config.prosody_dim)
        self.feature_fusion = nn.Sequential(
            nn.Linear(fused_in, fused_in // 2),
            nn.Tanh(),
            nn.LayerNorm(fused_in // 2),
            nn.Dropout(config.dropout_rate),
        )

    def forward(self, pos_ids, grammar_ids, durations, prosody_features, attention_mask):
        """Return a (batch, fused_dim // 2) pooled linguistic feature vector."""
        # Clamp out-of-vocabulary POS ids instead of crashing the embedding lookup.
        safe_pos = pos_ids.clamp(0, self.config.pos_vocab_size - 1)
        pos_stream = self.pos_attention(self.pos_embedding(safe_pos), attention_mask)

        grammar_stream = self.grammar_projection(grammar_ids.float())
        duration_stream = self.duration_projection(durations.unsqueeze(-1).float())
        prosody_stream = self.prosody_projection(prosody_features.float())

        fused = self.feature_fusion(torch.cat(
            [pos_stream, grammar_stream, duration_stream, prosody_stream], dim=-1))

        # Mask-aware mean pooling over the token axis.
        mask = attention_mask.unsqueeze(-1).float()
        return torch.sum(fused * mask, dim=1) / torch.sum(mask, dim=1)
|
|
|
|
|
|
|
|
|
class StableAphasiaClassifier(nn.Module):
    """Stable aphasia classification model.

    Fuses a BERT encoder (embeddings frozen) enhanced with positional
    encoding and the linguistic feature extractor, then feeds the fused
    vector to the main classification head plus auxiliary severity (4-way
    softmax) and fluency (scalar sigmoid) heads.
    """

    def __init__(self, config: ModelConfig, num_labels: int):
        super().__init__()
        self.config = config
        self.num_labels = num_labels

        # Pretrained biomedical BERT backbone.
        self.bert = AutoModel.from_pretrained(config.model_name)
        self.bert_config = self.bert.config

        # Freeze the embedding layer to stabilize fine-tuning.
        for param in self.bert.embeddings.parameters():
            param.requires_grad = False

        # Extra positional signal on top of BERT's hidden states.
        self.positional_encoder = StablePositionalEncoding(
            d_model=self.bert_config.hidden_size,
            max_len=config.max_length
        )

        self.linguistic_extractor = StableLinguisticFeatureExtractor(config)

        bert_dim = self.bert_config.hidden_size
        # Must match StableLinguisticFeatureExtractor's pooled output size
        # (total stream dim halved by its feature_fusion layer).
        linguistic_dim = (config.pos_emb_dim + config.grammar_hidden_dim +
                          config.duration_hidden_dim + config.prosody_dim) // 2

        # Fuse pooled BERT features with pooled linguistic features.
        self.feature_fusion = nn.Sequential(
            nn.Linear(bert_dim + linguistic_dim, bert_dim),
            nn.LayerNorm(bert_dim),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )

        # Main aphasia-type classification head.
        self.classifier = self._build_classifier(bert_dim, num_labels)

        # Auxiliary head: 4-way severity distribution.
        self.severity_head = nn.Sequential(
            nn.Linear(bert_dim, 4),
            nn.Softmax(dim=-1)
        )

        # Auxiliary head: scalar fluency score in (0, 1).
        self.fluency_head = nn.Sequential(
            nn.Linear(bert_dim, 1),
            nn.Sigmoid()
        )

    def _build_classifier(self, input_dim: int, num_labels: int):
        """Build an MLP: [Linear, LayerNorm, Tanh, Dropout] per hidden dim, then a logits layer."""
        layers = []
        current_dim = input_dim

        for hidden_dim in self.config.classifier_hidden_dims:
            layers.extend([
                nn.Linear(current_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.Tanh(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = hidden_dim

        layers.append(nn.Linear(current_dim, num_labels))
        return nn.Sequential(*layers)

    def forward(self, input_ids, attention_mask, labels=None,
                word_pos_ids=None, word_grammar_ids=None, word_durations=None,
                prosody_features=None, **kwargs):
        """Run the full model.

        Returns a dict with 'logits', 'severity_pred', 'fluency_pred' and
        'loss' (None when `labels` is not given). Linguistic inputs are
        optional; when any of them is missing, a zero vector stands in for
        the linguistic features.
        """
        # Contextual token representations from BERT.
        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_outputs.last_hidden_state

        # Add the extra positional signal.
        position_enhanced = self.positional_encoder(sequence_output)

        # Collapse the sequence to one vector per sample.
        pooled_output = self._attention_pooling(position_enhanced, attention_mask)

        # Linguistic branch: only runs when all word-level inputs are present.
        if all(x is not None for x in [word_pos_ids, word_grammar_ids, word_durations]):
            if prosody_features is None:
                # Default to a zero prosody tensor of the expected shape.
                batch_size, seq_len = input_ids.size()
                prosody_features = torch.zeros(
                    batch_size, seq_len, self.config.prosody_dim,
                    device=input_ids.device
                )

            linguistic_features = self.linguistic_extractor(
                word_pos_ids, word_grammar_ids, word_durations,
                prosody_features, attention_mask
            )
        else:
            # Zero vector with the same width the extractor would produce.
            linguistic_features = torch.zeros(
                input_ids.size(0),
                (self.config.pos_emb_dim + self.config.grammar_hidden_dim +
                 self.config.duration_hidden_dim + self.config.prosody_dim) // 2,
                device=input_ids.device
            )

        # Fuse both modalities.
        combined_features = torch.cat([pooled_output, linguistic_features], dim=1)
        fused_features = self.feature_fusion(combined_features)

        # Main and auxiliary predictions.
        logits = self.classifier(fused_features)
        severity_pred = self.severity_head(fused_features)
        fluency_pred = self.fluency_head(fused_features)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        return {
            "logits": logits,
            "severity_pred": severity_pred,
            "fluency_pred": fluency_pred,
            "loss": loss
        }

    def _attention_pooling(self, sequence_output, attention_mask):
        """Attention-based pooling.

        Weights come from a softmax over each token's summed hidden values;
        padding positions are zeroed after the softmax and the weights are
        then renormalized (the +1e-9 guards against division by zero).
        """
        attention_weights = torch.softmax(
            torch.sum(sequence_output, dim=-1, keepdim=True), dim=1
        )
        attention_weights = attention_weights * attention_mask.unsqueeze(-1).float()
        attention_weights = attention_weights / (torch.sum(attention_weights, dim=1, keepdim=True) + 1e-9)
        pooled = torch.sum(sequence_output * attention_weights, dim=1)
        return pooled

    def _compute_loss(self, logits, labels):
        """Classification loss per config.

        Focal loss takes precedence; label smoothing applies only when focal
        loss is disabled.
        """
        if self.config.use_focal_loss:
            focal_loss = FocalLoss(
                alpha=self.config.focal_alpha,
                gamma=self.config.focal_gamma,
                reduction='mean'
            )
            return focal_loss(logits, labels)
        else:
            if self.config.use_label_smoothing:
                return F.cross_entropy(
                    logits, labels,
                    label_smoothing=self.config.label_smoothing
                )
            else:
                return F.cross_entropy(logits, labels)
|
|
|
|
|
|
|
|
|
class StableAphasiaDataset(Dataset):
    """Stable dataset with simplified processing.

    Flattens each item's dialogues into one token stream, encodes it with the
    BERT tokenizer, and aligns per-word POS / grammar / duration features to
    the resulting subtoken positions. Items with unknown aphasia types are
    skipped (and logged).
    """

    def __init__(self, sentences, tokenizer, aphasia_types_mapping, config: "ModelConfig"):
        self.samples = []
        self.tokenizer = tokenizer
        self.config = config
        self.aphasia_types_mapping = aphasia_types_mapping

        # Markers used when flattening multi-dialogue items into one text.
        special_tokens = ["[DIALOGUE]", "[TURN]", "[PAUSE]", "[REPEAT]", "[HESITATION]"]
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

        for idx, item in enumerate(sentences):
            sentence_id = item.get("sentence_id", f"S{idx}")
            aphasia_type = normalize_type(item.get("aphasia_type", ""))

            if aphasia_type not in aphasia_types_mapping:
                log_message(f"Skipping Sentence {sentence_id}: Invalid aphasia type '{aphasia_type}'")
                continue

            self._process_sentence(item, sentence_id, aphasia_type)

        if not self.samples:
            raise ValueError("No valid samples found in dataset!")

        log_message(f"Dataset created with {len(self.samples)} samples")
        self._print_class_distribution()

    def _process_sentence(self, item, sentence_id, aphasia_type):
        """Flatten one item's PAR turns (all dialogues) into a single sample."""
        all_tokens, all_pos, all_grammar, all_durations = [], [], [], []

        for dialogue_idx, dialogue in enumerate(item.get("dialogues", [])):
            # Separate consecutive dialogues with a marker token carrying
            # neutral (zero) linguistic features.
            if dialogue_idx > 0:
                all_tokens.append("[DIALOGUE]")
                all_pos.append(0)
                all_grammar.append([0, 0, 0])
                all_durations.append(0.0)

            for par in dialogue.get("PAR", []):
                if "tokens" in par and par["tokens"]:
                    tokens = par["tokens"]
                    # Per-word features default to neutral values when absent.
                    pos_ids = par.get("word_pos_ids", [0] * len(tokens))
                    grammar_ids = par.get("word_grammar_ids", [[0, 0, 0]] * len(tokens))
                    durations = par.get("word_durations", [0.0] * len(tokens))

                    all_tokens.extend(tokens)
                    all_pos.extend(pos_ids)
                    all_grammar.extend(grammar_ids)
                    all_durations.extend(durations)

        if not all_tokens:
            return

        self._create_sample(all_tokens, all_pos, all_grammar, all_durations,
                            sentence_id, aphasia_type)

    def _create_sample(self, tokens, pos_ids, grammar_ids, durations,
                       sentence_id, aphasia_type):
        """Tokenize the flattened text, align features, and append the sample dict."""
        text = " ".join(tokens)
        encoded = self.tokenizer(
            text,
            max_length=self.config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        aligned_pos, aligned_grammar, aligned_durations = self._align_features(
            tokens, pos_ids, grammar_ids, durations, encoded
        )

        # Utterance-level prosody, broadcast to every token position.
        prosody_features = self._extract_prosodic_features(durations, tokens)
        prosody_tensor = torch.tensor(prosody_features).unsqueeze(0).repeat(
            self.config.max_length, 1
        )

        label = self.aphasia_types_mapping[aphasia_type]

        sample = {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
            "word_pos_ids": torch.tensor(aligned_pos, dtype=torch.long),
            "word_grammar_ids": torch.tensor(aligned_grammar, dtype=torch.long),
            "word_durations": torch.tensor(aligned_durations, dtype=torch.float),
            "prosody_features": prosody_tensor.float(),
            "sentence_id": sentence_id
        }
        self.samples.append(sample)

    def _align_features(self, tokens, pos_ids, grammar_ids, durations, encoded):
        """Align per-word features with BERT subtokens.

        Returns three lists of length config.max_length, with neutral values
        in the leading [CLS] slot, the trailing [SEP] slot, and any padding.
        """
        # Map each subtoken position back to the index of its source word.
        subtoken_to_token = []
        for token_idx, token in enumerate(tokens):
            subtokens = self.tokenizer.tokenize(token)
            subtoken_to_token.extend([token_idx] * len(subtokens))

        # Slot 0 is reserved for [CLS].
        aligned_pos = [0]
        aligned_grammar = [[0, 0, 0]]
        aligned_durations = [0.0]

        for subtoken_idx in range(1, self.config.max_length - 1):
            if subtoken_idx - 1 < len(subtoken_to_token):
                original_idx = subtoken_to_token[subtoken_idx - 1]
                aligned_pos.append(pos_ids[original_idx] if original_idx < len(pos_ids) else 0)
                aligned_grammar.append(grammar_ids[original_idx] if original_idx < len(grammar_ids) else [0, 0, 0])
                raw = durations[original_idx] if original_idx < len(durations) else 0.0
                aligned_durations.append(self._duration_value(raw))
            else:
                aligned_pos.append(0)
                aligned_grammar.append([0, 0, 0])
                aligned_durations.append(0.0)

        # Trailing slot for [SEP].
        aligned_pos.append(0)
        aligned_grammar.append([0, 0, 0])
        aligned_durations.append(0.0)

        return aligned_pos, aligned_grammar, aligned_durations

    @staticmethod
    def _duration_value(raw):
        """Convert one raw duration entry to a scalar.

        Bug fixes vs. the original:
        - the original indexed raw[1] before checking len(raw) >= 2, raising
          IndexError on single-element lists (its len check / raw[0] branch
          was unreachable);
        - plain numeric durations were silently zeroed, although
          _extract_prosodic_features treats durations as scalars.
        Supported forms: [start, end] int pair -> end - start; single-element
        numeric list -> that value; plain int/float -> itself; else 0.0.
        """
        if isinstance(raw, list):
            if len(raw) >= 2 and isinstance(raw[0], int) and isinstance(raw[1], int):
                return int(raw[1]) - int(raw[0])
            if len(raw) == 1 and isinstance(raw[0], (int, float)):
                return float(raw[0])
            return 0.0
        if isinstance(raw, (int, float)):
            return float(raw)
        return 0.0

    def _extract_prosodic_features(self, durations, tokens):
        """Summarize the word durations into a fixed-size prosody vector.

        Features: mean, std, median, and the count of "long" words
        (> 1.5x the mean), zero-padded/truncated to config.prosody_dim.
        """
        if not durations:
            return [0.0] * self.config.prosody_dim

        valid_durations = [d for d in durations if isinstance(d, (int, float)) and d > 0]
        if not valid_durations:
            return [0.0] * self.config.prosody_dim

        # Hoisted out of the comprehension (was recomputed per element).
        mean_duration = np.mean(valid_durations)
        features = [
            mean_duration,
            np.std(valid_durations),
            np.median(valid_durations),
            len([d for d in valid_durations if d > mean_duration * 1.5])
        ]

        # Pad with zeros up to the configured width.
        while len(features) < self.config.prosody_dim:
            features.append(0.0)

        return features[:self.config.prosody_dim]

    def _print_class_distribution(self):
        """Log how many samples each aphasia class received."""
        label_counts = Counter(sample["labels"].item() for sample in self.samples)
        reverse_mapping = {v: k for k, v in self.aphasia_types_mapping.items()}

        log_message("\nClass Distribution:")
        for label_id, count in sorted(label_counts.items()):
            class_name = reverse_mapping.get(label_id, f"Unknown_{label_id}")
            log_message(f"  {class_name}: {count} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
|
|
|
|
|
|
|
|
|
def stable_collate_fn(batch):
    """Collate dataset samples into batched tensors; returns None for empty/bad batches."""
    if not batch or batch[0] is None:
        return None

    try:
        seq_len = batch[0]["input_ids"].size(0)

        def stacked(key, default_factory):
            # Missing optional keys fall back to a zero tensor of the right shape.
            return torch.stack([item.get(key, default_factory()) for item in batch])

        return {
            "input_ids": torch.stack([item["input_ids"] for item in batch]),
            "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
            "sentence_ids": [item.get("sentence_id", "N/A") for item in batch],
            "word_pos_ids": stacked("word_pos_ids", lambda: torch.zeros(seq_len, dtype=torch.long)),
            "word_grammar_ids": stacked("word_grammar_ids", lambda: torch.zeros(seq_len, 3, dtype=torch.long)),
            "word_durations": stacked("word_durations", lambda: torch.zeros(seq_len, dtype=torch.float)),
            "prosody_features": stacked("prosody_features", lambda: torch.zeros(seq_len, 32, dtype=torch.float)),
        }
    except Exception as e:
        log_message(f"Collation error: {e}")
        return None
|
|
|
|
|
|
|
|
|
class AdaptiveTrainingCallback(TrainerCallback):
|
|
|
"""Enhanced training callback with adaptive learning rate and comprehensive tracking"""
|
|
|
def __init__(self, config: ModelConfig, patience=5, min_delta=0.8):
|
|
|
self.config = config
|
|
|
self.patience = patience
|
|
|
self.min_delta = min_delta
|
|
|
self.best_metric = float('-inf')
|
|
|
self.patience_counter = 0
|
|
|
|
|
|
|
|
|
self.lr_scheduler = None
|
|
|
|
|
|
|
|
|
self.history_tracker = TrainingHistoryTracker()
|
|
|
|
|
|
|
|
|
self.current_train_metrics = {}
|
|
|
self.current_eval_metrics = {}
|
|
|
|
|
|
def on_train_begin(self, args, state, control, **kwargs):
|
|
|
"""Initialize learning rate scheduler"""
|
|
|
if self.config.adaptive_lr:
|
|
|
model = kwargs.get('model')
|
|
|
optimizer = kwargs.get('optimizer')
|
|
|
if optimizer and model:
|
|
|
total_steps = state.max_steps if state.max_steps > 0 else len(kwargs.get('train_dataloader', [])) * args.num_train_epochs
|
|
|
self.lr_scheduler = AdaptiveLearningRateScheduler(optimizer, self.config, total_steps)
|
|
|
log_message("Adaptive learning rate scheduler initialized")
|
|
|
|
|
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
|
|
"""Capture training metrics"""
|
|
|
if logs:
|
|
|
|
|
|
if 'train_loss' in logs:
|
|
|
self.current_train_metrics['loss'] = logs['train_loss']
|
|
|
if 'learning_rate' in logs:
|
|
|
self.current_train_metrics['lr'] = logs['learning_rate']
|
|
|
|
|
|
def on_evaluate(self, args, state, control, logs=None, **kwargs):
|
|
|
"""Handle evaluation and learning rate adjustment"""
|
|
|
if logs is not None:
|
|
|
current_metric = logs.get('eval_f1', 0)
|
|
|
current_loss = logs.get('eval_loss', float('inf'))
|
|
|
current_acc = logs.get('eval_accuracy', 0)
|
|
|
|
|
|
|
|
|
self.current_eval_metrics = {
|
|
|
'loss': current_loss,
|
|
|
'f1': current_metric,
|
|
|
'accuracy': current_acc,
|
|
|
'precision': logs.get('eval_precision_macro', 0),
|
|
|
'recall': logs.get('eval_recall_macro', 0)
|
|
|
}
|
|
|
|
|
|
|
|
|
epoch_metrics = {
|
|
|
'train_loss': self.current_train_metrics.get('loss', 0),
|
|
|
'eval_loss': current_loss,
|
|
|
'train_accuracy': 0,
|
|
|
'eval_accuracy': current_acc,
|
|
|
'train_f1': 0,
|
|
|
'eval_f1': current_metric,
|
|
|
'learning_rate': self.current_train_metrics.get('lr', self.config.learning_rate),
|
|
|
'train_precision': 0,
|
|
|
'eval_precision': logs.get('eval_precision_macro', 0),
|
|
|
'train_recall': 0,
|
|
|
'eval_recall': logs.get('eval_recall_macro', 0)
|
|
|
}
|
|
|
|
|
|
self.history_tracker.update(state.epoch, epoch_metrics)
|
|
|
|
|
|
|
|
|
if self.lr_scheduler and self.config.adaptive_lr:
|
|
|
new_lr = self.lr_scheduler.adaptive_lr_calculation(current_loss, current_metric, current_acc)
|
|
|
if current_acc > 0.84:
|
|
|
log_message(f"Target accuracy reached ({current_acc:.2%}) → stopping and saving model")
|
|
|
control.should_save = True
|
|
|
control.should_training_stop = True
|
|
|
return control
|
|
|
|
|
|
if current_metric > self.best_metric + self.min_delta:
|
|
|
self.best_metric = current_metric
|
|
|
self.patience_counter = 0
|
|
|
log_message(f"New best F1 score: {current_metric:.4f}")
|
|
|
else:
|
|
|
self.patience_counter += 1
|
|
|
log_message(f"No improvement for {self.patience_counter} evaluations")
|
|
|
|
|
|
if self.patience_counter >= self.patience:
|
|
|
log_message("Early stopping triggered")
|
|
|
control.should_training_stop = True
|
|
|
|
|
|
clear_memory()
|
|
|
|
|
|
def on_train_end(self, args, state, control, **kwargs):
    """Persist the accumulated training history when training finishes.

    Invoked once by the Trainer at the end of training; writes the history
    file and the training-curve figures into the run's output directory.
    """
    destination = args.output_dir
    tracker = self.history_tracker
    tracker.save_history(destination)
    tracker.plot_training_curves(destination)
    log_message("Training history and curves saved")
|
|
|
|
|
|
|
|
|
def compute_comprehensive_metrics(pred):
    """Compute comprehensive evaluation metrics for a Trainer prediction.

    Accepts the EvalPrediction-style object the HF Trainer hands to
    ``compute_metrics``: ``pred.predictions`` holds the logits (possibly as
    the first element of a tuple) and ``pred.label_ids`` the gold labels.

    Returns a dict of scalar metrics. The ``*_std`` entries are standard
    deviations of the per-class scores, a rough measure of how evenly the
    model performs across classes.
    """
    raw_logits = pred.predictions
    if isinstance(raw_logits, tuple):
        raw_logits = raw_logits[0]
    labels = pred.label_ids

    preds = np.argmax(raw_logits, axis=1)

    # Per-class score vectors, used only for the dispersion statistics.
    f1_each = f1_score(labels, preds, average=None, zero_division=0)
    precision_each = precision_score(labels, preds, average=None, zero_division=0)
    recall_each = recall_score(labels, preds, average=None, zero_division=0)

    return {
        "accuracy": accuracy_score(labels, preds),
        # "f1" (the metric used for best-model selection) is the *weighted*
        # average; the macro average is reported separately below.
        "f1": f1_score(labels, preds, average='weighted', zero_division=0),
        "f1_macro": f1_score(labels, preds, average='macro', zero_division=0),
        "precision_macro": precision_score(labels, preds, average='macro', zero_division=0),
        "recall_macro": recall_score(labels, preds, average='macro', zero_division=0),
        "f1_std": np.std(f1_each),
        "precision_std": np.std(precision_each),
        "recall_std": np.std(recall_each)
    }
|
|
|
|
|
|
|
|
|
def generate_comprehensive_reports(trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir):
    """Generate comprehensive analysis reports and visualizations.

    Runs the trained model over ``eval_dataset`` and writes to ``output_dir``:
      - enhanced_confusion_matrix.png (counts + row percentages)
      - comprehensive_classification_report.csv
      - per_class_metrics.csv and per_class_performance.png
      - confidence_distribution.png
      - comprehensive_results.csv (per-sentence predictions + class probs)
      - summary_statistics.json

    Returns (results_df, df_report, summary_stats).
    """
    log_message("Generating comprehensive reports...")

    # Unwrap DataParallel/DDP if necessary before inference.
    model = trainer.model
    if hasattr(model, 'module'):
        model = model.module

    model.eval()
    device = next(model.parameters()).device

    predictions = []
    true_labels = []
    sentence_ids = []
    severity_preds = []
    fluency_preds = []
    prediction_probs = []

    dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=stable_collate_fn)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                # stable_collate_fn may drop whole batches; skip them.
                continue

            # Move model inputs to the model's device; 'sentence_ids' is
            # intentionally left on CPU.
            for key in ['input_ids', 'attention_mask', 'word_pos_ids',
                        'word_grammar_ids', 'word_durations', 'labels', 'prosody_features']:
                if key in batch:
                    batch[key] = batch[key].to(device)

            # NOTE(review): batch may still contain non-tensor keys such as
            # 'sentence_ids'; assumes the model's forward() tolerates the
            # extra kwargs — confirm against StableAphasiaClassifier.
            outputs = model(**batch)

            logits = outputs["logits"]
            probs = F.softmax(logits, dim=1)
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(batch["labels"].cpu().numpy())
            sentence_ids.extend(batch["sentence_ids"])
            severity_preds.extend(outputs["severity_pred"].cpu().numpy())
            fluency_preds.extend(outputs["fluency_pred"].cpu().numpy())
            prediction_probs.extend(probs.cpu().numpy())

    # label-id -> type-name lookup for human-readable output.
    reverse_mapping = {v: k for k, v in aphasia_types_mapping.items()}

    log_message("=== DETAILED PREDICTIONS (First 20) ===")
    for i in range(min(20, len(predictions))):
        true_type = reverse_mapping.get(true_labels[i], 'Unknown')
        pred_type = reverse_mapping.get(predictions[i], 'Unknown')
        severity_level = np.argmax(severity_preds[i])
        fluency_score = fluency_preds[i][0] if isinstance(fluency_preds[i], np.ndarray) else fluency_preds[i]
        confidence = np.max(prediction_probs[i])

        log_message(f"ID: {sentence_ids[i]} | True: {true_type} | Pred: {pred_type} | "
                    f"Confidence: {confidence:.3f} | Severity: {severity_level} | Fluency: {fluency_score:.3f}")

    cm = confusion_matrix(true_labels, predictions)

    plt.figure(figsize=(14, 12))

    # Row-normalized percentages for the annotation text.
    # NOTE(review): a class absent from true_labels yields a zero row sum and
    # a divide-by-zero (NaN percentages). Also, tick labels below assume the
    # matrix covers all mapping classes in id order, which confusion_matrix
    # only guarantees when every class is present — confirm.
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    # Each cell shows "count\n(pct%)".
    annotations = np.empty_like(cm, dtype=object)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            annotations[i, j] = f'{cm[i, j]}\n({cm_percentage[i, j]:.1f}%)'

    sns.heatmap(cm, annot=annotations, fmt='', cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()),
                cbar_kws={'label': 'Count'})

    plt.xlabel("Predicted Label", fontsize=12, fontweight='bold')
    plt.ylabel("True Label", fontsize=12, fontweight='bold')
    plt.title("Enhanced Confusion Matrix\n(Count and Percentage)", fontsize=14, fontweight='bold')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "enhanced_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Full classification report over the complete label set, so classes
    # with zero support still appear (zero_division=0 avoids warnings).
    all_label_ids = list(aphasia_types_mapping.values())
    report_dict = classification_report(
        true_labels,
        predictions,
        labels=all_label_ids,
        target_names=list(aphasia_types_mapping.keys()),
        output_dict=True,
        zero_division=0
    )

    df_report = pd.DataFrame(report_dict).transpose()
    df_report.to_csv(os.path.join(output_dir, "comprehensive_classification_report.csv"))

    # Per-class metric table for CSV export and the bar-chart panel.
    class_names = list(aphasia_types_mapping.keys())
    metrics_data = []

    for i, class_name in enumerate(class_names):
        if class_name in report_dict:
            metrics_data.append({
                'Class': class_name,
                'Precision': report_dict[class_name]['precision'],
                'Recall': report_dict[class_name]['recall'],
                'F1-Score': report_dict[class_name]['f1-score'],
                'Support': report_dict[class_name]['support']
            })

    df_metrics = pd.DataFrame(metrics_data)
    df_metrics.to_csv(os.path.join(output_dir, "per_class_metrics.csv"), index=False)

    # 2x2 bar-chart panel: precision / recall / F1 / support per class.
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    axes[0, 0].bar(df_metrics['Class'], df_metrics['Precision'], color='skyblue', alpha=0.8)
    axes[0, 0].set_title('Precision by Class', fontweight='bold')
    axes[0, 0].set_ylabel('Precision')
    axes[0, 0].tick_params(axis='x', rotation=45)
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].bar(df_metrics['Class'], df_metrics['Recall'], color='lightcoral', alpha=0.8)
    axes[0, 1].set_title('Recall by Class', fontweight='bold')
    axes[0, 1].set_ylabel('Recall')
    axes[0, 1].tick_params(axis='x', rotation=45)
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].bar(df_metrics['Class'], df_metrics['F1-Score'], color='lightgreen', alpha=0.8)
    axes[1, 0].set_title('F1-Score by Class', fontweight='bold')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].tick_params(axis='x', rotation=45)
    axes[1, 0].grid(True, alpha=0.3)

    axes[1, 1].bar(df_metrics['Class'], df_metrics['Support'], color='gold', alpha=0.8)
    axes[1, 1].set_title('Support by Class', fontweight='bold')
    axes[1, 1].set_ylabel('Support (Number of Samples)')
    axes[1, 1].tick_params(axis='x', rotation=45)
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "per_class_performance.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Confidence analysis: max softmax probability per sample, split by
    # whether the prediction was correct.
    confidences = [np.max(prob) for prob in prediction_probs]
    correct_predictions = [pred == true for pred, true in zip(predictions, true_labels)]

    plt.figure(figsize=(12, 8))

    correct_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if correct]
    incorrect_confidences = [conf for conf, correct in zip(confidences, correct_predictions) if not correct]

    plt.hist(correct_confidences, bins=30, alpha=0.7, label='Correct Predictions', color='green', density=True)
    plt.hist(incorrect_confidences, bins=30, alpha=0.7, label='Incorrect Predictions', color='red', density=True)

    plt.xlabel('Prediction Confidence', fontsize=12)
    plt.ylabel('Density', fontsize=12)
    plt.title('Distribution of Prediction Confidence', fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confidence_distribution.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # Auxiliary-head statistics (severity distribution, fluency score).
    log_message("=== FEATURE ANALYSIS ===")
    avg_severity = np.mean(severity_preds, axis=0)
    avg_fluency = np.mean(fluency_preds)
    std_fluency = np.std(fluency_preds)

    log_message(f"Average Severity Distribution: {avg_severity}")
    log_message(f"Average Fluency Score: {avg_fluency:.3f} ± {std_fluency:.3f}")

    # Per-sentence results table.
    results_df = pd.DataFrame({
        'sentence_id': sentence_ids,
        'true_label': [reverse_mapping[label] for label in true_labels],
        'predicted_label': [reverse_mapping[pred] for pred in predictions],
        'prediction_confidence': confidences,
        'correct_prediction': correct_predictions,
        'severity_level': [np.argmax(severity) for severity in severity_preds],
        'fluency_score': [fluency[0] if isinstance(fluency, np.ndarray) else fluency for fluency in fluency_preds]
    })

    # One probability column per class; assumes logit index i corresponds to
    # mapping value i (which matches insertion order here) — TODO confirm.
    for i, class_name in enumerate(aphasia_types_mapping.keys()):
        results_df[f'prob_{class_name}'] = [prob[i] for prob in prediction_probs]

    results_df.to_csv(os.path.join(output_dir, "comprehensive_results.csv"), index=False)

    # Aggregate summary statistics.
    summary_stats = {
        'Overall Accuracy': accuracy_score(true_labels, predictions),
        'Macro F1': f1_score(true_labels, predictions, average='macro'),
        'Weighted F1': f1_score(true_labels, predictions, average='weighted'),
        'Macro Precision': precision_score(true_labels, predictions, average='macro'),
        'Macro Recall': recall_score(true_labels, predictions, average='macro'),
        'Average Confidence': np.mean(confidences),
        'Confidence Std': np.std(confidences),
        'Average Severity': avg_severity.tolist(),
        'Average Fluency': avg_fluency,
        'Fluency Std': std_fluency
    }

    # Cast numpy scalars to plain Python floats so json.dump succeeds.
    serializable_summary = {
        k: float(v) if isinstance(v, (np.floating, np.integer)) else v
        for k, v in summary_stats.items()
    }
    with open(os.path.join(output_dir, "summary_statistics.json"), "w") as f:
        json.dump(serializable_summary, f, indent=2)

    log_message("Comprehensive Classification Report:")
    log_message(df_report.to_string())
    log_message(f"Comprehensive results saved to {output_dir}")

    return results_df, df_report, summary_stats
|
|
|
|
|
|
|
|
|
def train_adaptive_model(json_file: str, output_dir: str = "./adaptive_aphasia_model"):
    """Main training function with adaptive learning rate.

    Loads the augmented aphasia dataset from ``json_file``, trains a
    StableAphasiaClassifier with the HF Trainer plus the adaptive-LR
    callback, generates the comprehensive reports, and saves the model
    weights, tokenizer and run config into ``output_dir``.

    Returns (trainer, eval_results, results_df).
    """

    log_message("Starting Adaptive Aphasia Classification Training")
    log_message("=" * 60)

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_message(f"Using device: {device}")

    log_message("Loading dataset...")
    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Canonicalize type labels before filtering (normalize_type is defined
    # elsewhere in this file).
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    # Fixed label-id mapping; must stay stable across runs so saved models
    # remain loadable.
    aphasia_types_mapping = {
        "BROCA": 0,
        "TRANSMOTOR": 1,
        "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3,
        "WERNICKE": 4,
        "ANOMIC": 5,
        "GLOBAL": 6,
        "ISOLATION": 7,
        "TRANSSENSORY": 8
    }

    log_message(f"Aphasia Types Mapping: {aphasia_types_mapping}")

    num_labels = len(aphasia_types_mapping)
    log_message(f"Number of labels: {num_labels}")

    # Drop sentences whose (normalized) type is not in the mapping.
    filtered_sentences = []
    for item in sentences:
        aphasia_type = item.get("aphasia_type", "")
        if aphasia_type in aphasia_types_mapping:
            filtered_sentences.append(item)
        else:
            log_message(f"Excluding sentence with invalid type: {aphasia_type}")

    log_message(f"Filtered dataset: {len(filtered_sentences)} sentences")

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # In-place shuffle (seeded globally via set_all_seeds) before the split.
    random.shuffle(filtered_sentences)
    dataset_all = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    # 80/20 train/eval split. NOTE: not stratified, so rare classes may be
    # unevenly represented between the splits.
    total_samples = len(dataset_all)
    train_size = int(0.8 * total_samples)
    eval_size = total_samples - train_size

    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset_all, [train_size, eval_size]
    )

    log_message(f"Train size: {train_size}, Eval size: {eval_size}")

    # Inverse-frequency sample weights for class balancing.
    # NOTE(review): this WeightedRandomSampler is constructed but never
    # passed to the Trainer (which builds its own dataloader), so it
    # currently has no effect — confirm whether a custom
    # get_train_dataloader was intended.
    train_labels = [dataset_all.samples[idx]["labels"].item() for idx in train_dataset.indices]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    # Fresh model factory; resize embeddings in case the tokenizer grew.
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    training_args = TrainingArguments(
        output_dir=output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_strategy="steps",
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to=None,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=3,
        remove_unused_columns=False,  # custom features must reach the collator
    )

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        # NOTE(review): min_delta=0.8 on an F1 score in [0, 1] makes the
        # callback's improvement check nearly unsatisfiable — verify the
        # intended value (0.008?).
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.8)]
    )

    log_message("Starting adaptive training...")
    try:
        trainer.train()
        log_message("Training completed successfully!")
    except Exception as e:
        # Log the full traceback before re-raising so failures are captured
        # in the run log.
        log_message(f"Training error: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
        raise

    log_message("Starting final evaluation...")
    eval_results = trainer.evaluate()
    log_message(f"Final evaluation results: {eval_results}")

    results_df, report_df, summary_stats = generate_comprehensive_reports(
        trainer, eval_dataset, aphasia_types_mapping, tokenizer, output_dir
    )

    # Unwrap DataParallel/DDP before saving raw weights.
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(output_dir)

    # Persist everything needed to reload/reproduce the run.
    config_dict = {
        "model_name": config.model_name,
        "num_labels": num_labels,
        "aphasia_types_mapping": aphasia_types_mapping,
        "training_args": training_args.to_dict(),
        "adaptive_lr_config": {
            "adaptive_lr": config.adaptive_lr,
            "lr_patience": config.lr_patience,
            "lr_factor": config.lr_factor,
            "lr_increase_factor": config.lr_increase_factor,
            "min_lr": config.min_lr,
            "max_lr": config.max_lr,
            "oscillation_amplitude": config.oscillation_amplitude
        }
    }

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    log_message(f"Adaptive model and comprehensive reports saved to {output_dir}")
    clear_memory()

    return trainer, eval_results, results_df
|
|
|
|
|
|
|
|
|
def train_adaptive_cross_validation(json_file: str, output_dir: str = "./adaptive_cv_results", n_folds: int = 5):
    """Cross-validation training with adaptive learning rate.

    Runs stratified k-fold CV over the dataset: each fold is trained via
    train_adaptive_single_fold, per-fold metrics are aggregated into
    adaptive_cv_summary.csv / cv_statistics.json, and pooled out-of-fold
    predictions feed an overall confusion matrix plus per-fold comparison
    plots.

    Returns (results_df, cv_summary).
    """
    log_message("Starting Adaptive Cross-Validation Training")

    config = ModelConfig()
    os.makedirs(output_dir, exist_ok=True)

    with open(json_file, "r", encoding="utf-8") as f:
        dataset_json = json.load(f)

    sentences = dataset_json.get("sentences", [])

    # Canonicalize labels before filtering (same as single-model training).
    for item in sentences:
        if "aphasia_type" in item:
            item["aphasia_type"] = normalize_type(item["aphasia_type"])

    aphasia_types_mapping = {
        "BROCA": 0, "TRANSMOTOR": 1, "NOTAPHASICBYWAB": 2,
        "CONDUCTION": 3, "WERNICKE": 4, "ANOMIC": 5,
        "GLOBAL": 6, "ISOLATION": 7, "TRANSSENSORY": 8
    }

    filtered_sentences = [s for s in sentences if s.get("aphasia_type") in aphasia_types_mapping]

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    full_dataset = StableAphasiaDataset(
        filtered_sentences, tokenizer, aphasia_types_mapping, config
    )

    # Label per pre-tokenized sample, used for stratification.
    sample_labels = [sample["labels"].item() for sample in full_dataset.samples]

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    fold_results = []
    all_predictions = []
    all_true_labels = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(sample_labels)), sample_labels)):
        log_message(f"\n=== Fold {fold + 1}/{n_folds} ===")

        train_subset = Subset(full_dataset, train_idx)
        val_subset = Subset(full_dataset, val_idx)

        fold_trainer, fold_results_dict, fold_predictions = train_adaptive_single_fold(
            train_subset, val_subset, config, aphasia_types_mapping,
            tokenizer, fold, output_dir
        )

        # BUGFIX: trainer.evaluate() prefixes every metric with 'eval_' (the
        # Trainer's default metric_key_prefix), so the original dict keys were
        # 'eval_accuracy', 'eval_f1', ... while the summary code below indexes
        # bare names ('accuracy', 'f1', ...), raising a KeyError. Strip the
        # prefix here; this is a no-op if the keys are already bare.
        fold_metrics = {
            (k[len('eval_'):] if k.startswith('eval_') else k): v
            for k, v in fold_results_dict.items()
        }

        fold_results.append({
            'fold': fold + 1,
            **fold_metrics
        })

        # Pool per-fold predictions for the overall confusion matrix.
        all_predictions.extend(fold_predictions['predictions'])
        all_true_labels.extend(fold_predictions['true_labels'])

        clear_memory()

    # Per-fold summary CSV.
    results_df = pd.DataFrame(fold_results)
    results_df.to_csv(os.path.join(output_dir, "adaptive_cv_summary.csv"), index=False)

    # Aggregate mean/std statistics across folds.
    cv_summary = {
        'mean_accuracy': results_df['accuracy'].mean(),
        'std_accuracy': results_df['accuracy'].std(),
        'mean_f1': results_df['f1'].mean(),
        'std_f1': results_df['f1'].std(),
        'mean_f1_macro': results_df['f1_macro'].mean(),
        'std_f1_macro': results_df['f1_macro'].std(),
        'mean_precision': results_df['precision_macro'].mean(),
        'std_precision': results_df['precision_macro'].std(),
        'mean_recall': results_df['recall_macro'].mean(),
        'std_recall': results_df['recall_macro'].std()
    }

    with open(os.path.join(output_dir, "cv_statistics.json"), "w") as f:
        json.dump(cv_summary, f, indent=2)

    # Overall confusion matrix across all folds' validation predictions.
    overall_cm = confusion_matrix(all_true_labels, all_predictions)

    plt.figure(figsize=(12, 10))
    sns.heatmap(overall_cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=list(aphasia_types_mapping.keys()),
                yticklabels=list(aphasia_types_mapping.keys()))
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.title("Overall Confusion Matrix (All Folds)")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "overall_confusion_matrix.png"), dpi=300, bbox_inches='tight')
    plt.close()

    # 2x2 panel comparing accuracy / F1 / precision / recall across folds,
    # each with its cross-fold mean as a dashed reference line.
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    axes[0, 0].bar(range(1, n_folds + 1), results_df['accuracy'], color='skyblue', alpha=0.8)
    axes[0, 0].axhline(y=results_df['accuracy'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["accuracy"].mean():.3f}')
    axes[0, 0].set_title('Accuracy Across Folds')
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    axes[0, 1].bar(range(1, n_folds + 1), results_df['f1'], color='lightgreen', alpha=0.8)
    axes[0, 1].axhline(y=results_df['f1'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["f1"].mean():.3f}')
    axes[0, 1].set_title('F1 Score Across Folds')
    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('F1 Score')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    axes[1, 0].bar(range(1, n_folds + 1), results_df['precision_macro'], color='coral', alpha=0.8)
    axes[1, 0].axhline(y=results_df['precision_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["precision_macro"].mean():.3f}')
    axes[1, 0].set_title('Precision Across Folds')
    axes[1, 0].set_xlabel('Fold')
    axes[1, 0].set_ylabel('Precision')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    axes[1, 1].bar(range(1, n_folds + 1), results_df['recall_macro'], color='gold', alpha=0.8)
    axes[1, 1].axhline(y=results_df['recall_macro'].mean(), color='red', linestyle='--',
                       label=f'Mean: {results_df["recall_macro"].mean():.3f}')
    axes[1, 1].set_title('Recall Across Folds')
    axes[1, 1].set_xlabel('Fold')
    axes[1, 1].set_ylabel('Recall')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "cv_performance_comparison.png"), dpi=300, bbox_inches='tight')
    plt.close()

    log_message("\n=== Adaptive Cross-Validation Summary ===")
    log_message(results_df.to_string(index=False))

    log_message(f"\nMean F1: {results_df['f1'].mean():.4f} ± {results_df['f1'].std():.4f}")
    log_message(f"Mean Accuracy: {results_df['accuracy'].mean():.4f} ± {results_df['accuracy'].std():.4f}")
    log_message(f"Mean F1 Macro: {results_df['f1_macro'].mean():.4f} ± {results_df['f1_macro'].std():.4f}")

    return results_df, cv_summary
|
|
|
|
|
|
def train_adaptive_single_fold(train_dataset, val_dataset, config, aphasia_types_mapping,
                               tokenizer, fold, output_dir):
    """Train a single fold with adaptive learning rate.

    Trains a fresh StableAphasiaClassifier on ``train_dataset``, evaluates
    on ``val_dataset``, saves the fold's weights under
    ``output_dir/fold_{fold}``, and returns
    (trainer, eval_results, fold_predictions), where fold_predictions maps
    'predictions'/'true_labels' to the fold's label-id lists.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_labels = len(aphasia_types_mapping)

    # Inverse-frequency sample weights for class balancing.
    # NOTE(review): the sampler is constructed but never handed to the
    # Trainer, which builds its own dataloader — it currently has no effect.
    train_labels = [train_dataset[i]["labels"].item() for i in range(len(train_dataset))]
    label_counts = Counter(train_labels)
    sample_weights = [1.0 / label_counts[label] for label in train_labels]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

    # Fresh model per fold; resize embeddings in case the tokenizer grew.
    def model_init():
        model = StableAphasiaClassifier(config, num_labels)
        model.bert.resize_token_embeddings(len(tokenizer))
        return model.to(device)

    fold_output_dir = os.path.join(output_dir, f"fold_{fold}")
    os.makedirs(fold_output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=fold_output_dir,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=config.learning_rate,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        num_train_epochs=config.num_epochs,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=50,
        seed=42,
        dataloader_num_workers=0,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        max_grad_norm=1.0,
        fp16=False,
        dataloader_drop_last=True,
        report_to=None,
        load_best_model_at_end=True,
        metric_for_best_model="eval_f1",
        greater_is_better=True,
        save_total_limit=1,  # keep only the best checkpoint per fold
        remove_unused_columns=False,  # custom features must reach the collator
    )

    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_comprehensive_metrics,
        data_collator=stable_collate_fn,
        # NOTE(review): min_delta=0.8 on an F1 in [0, 1] makes the callback's
        # improvement check nearly unsatisfiable — verify intended value.
        callbacks=[AdaptiveTrainingCallback(config, patience=5, min_delta=0.8)]
    )

    trainer.train()

    eval_results = trainer.evaluate()

    # Out-of-fold predictions (logits may come back as a tuple whose first
    # element holds the classification logits).
    predictions = trainer.predict(val_dataset)
    pred_labels = np.argmax(predictions.predictions[0] if isinstance(predictions.predictions, tuple) else predictions.predictions, axis=1)
    true_labels = predictions.label_ids

    fold_predictions = {
        'predictions': pred_labels.tolist(),
        'true_labels': true_labels.tolist()
    }

    # Unwrap DataParallel/DDP before saving raw weights.
    model_to_save = trainer.model
    if hasattr(model_to_save, 'module'):
        model_to_save = model_to_save.module

    torch.save(model_to_save.state_dict(), os.path.join(fold_output_dir, "pytorch_model.bin"))

    return trainer, eval_results, fold_predictions
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Adaptive Learning Rate Aphasia Classification Training")
    parser.add_argument("--output_dir", type=str, default="./adaptive_aphasia_model", help="Output directory")
    parser.add_argument("--cross_validation", action="store_true", help="Use cross-validation")
    parser.add_argument("--n_folds", type=int, default=5, help="Number of CV folds")
    parser.add_argument("--json_file", type=str, default=json_file, help="Path to JSON dataset file")
    parser.add_argument("--learning_rate", type=float, default=5e-4, help="Initial learning rate")
    parser.add_argument("--batch_size", type=int, default=24, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of epochs")
    # BUGFIX: the original declared --adaptive_lr with action="store_true"
    # AND default=True, which made the flag a no-op with no way to disable
    # the feature. A store_true/store_false pair keeps --adaptive_lr working
    # (backward compatible) and adds --no_adaptive_lr to turn it off.
    parser.add_argument("--adaptive_lr", dest="adaptive_lr", action="store_true",
                        help="Use adaptive learning rate (default: on)")
    parser.add_argument("--no_adaptive_lr", dest="adaptive_lr", action="store_false",
                        help="Disable adaptive learning rate")
    parser.set_defaults(adaptive_lr=True)

    args = parser.parse_args()

    # NOTE(review): these overrides mutate a local ModelConfig that is never
    # handed to train_adaptive_model / train_adaptive_cross_validation —
    # both construct their own ModelConfig() internally, so the CLI
    # lr/batch/epoch/adaptive_lr values currently do not affect training.
    # Confirm and thread this config through if that is unintended.
    config = ModelConfig()
    config.learning_rate = args.learning_rate
    config.batch_size = args.batch_size
    config.num_epochs = args.num_epochs
    config.adaptive_lr = args.adaptive_lr

    try:
        clear_memory()

        log_message(f"Starting training with adaptive_lr={config.adaptive_lr}")
        log_message(f"Config: lr={config.learning_rate}, batch_size={config.batch_size}, epochs={config.num_epochs}")

        if args.cross_validation:
            results_df, cv_summary = train_adaptive_cross_validation(args.json_file, args.output_dir, args.n_folds)
            log_message("Cross-validation training completed!")
        else:
            trainer, eval_results, results_df = train_adaptive_model(args.json_file, args.output_dir)
            log_message("Single model training completed!")

        log_message("All adaptive training completed successfully!")

    except Exception as e:
        # Best-effort logging of any failure; the process still exits cleanly
        # after the finally-block cleanup (original behavior preserved).
        log_message(f"Training failed: {str(e)}")
        import traceback
        log_message(traceback.format_exc())
    finally:
        clear_memory()