import json
import os
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def set_chinese_font():
    """Configure matplotlib so that Chinese titles and labels render.

    matplotlib walks the ``font.sans-serif`` list and uses the first family that
    is installed, so several common CJK-capable fonts are listed as candidates.
    """
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    # Draw the minus sign as an ASCII hyphen; many CJK fonts lack a U+2212 glyph.
    plt.rcParams['axes.unicode_minus'] = False


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:  # a bare filename has no directory component to create
        os.makedirs(parent, exist_ok=True)


def _default_results_image_dir():
    """Return ``<Config.RESULTS_DIR>/images``, importing ``Config`` on demand.

    ``Config`` is otherwise only imported by this file's ``__main__`` section, so
    resolving it here keeps the plotting functions usable when this file is
    imported as a module.
    """
    from src.config import Config

    return os.path.join(Config.RESULTS_DIR, "images")


def _rows_with_column(df, column):
    """Return the rows of ``df`` where ``column`` is set (no rows if it is absent)."""
    if column not in df.columns:
        return df.iloc[0:0]
    return df[df[column].notna()]


def plot_data_distribution(dataset_dict, save_path=None):
    """Draw a pie chart of the Positive/Neutral/Negative label distribution.

    Args:
        dataset_dict: Either a ``DatasetDict``-like mapping that has a ``'train'``
            split (only that split is counted) or a single dataset. When the
            dataset exposes ``features`` containing ``'label'`` or ``'labels'``,
            that column is read directly. Otherwise the dataset is iterated and
            each record's ``'label'`` (falling back to ``'labels'``) is used.
        save_path: If given, the chart is written there (missing parent
            directories are created) and the figure is closed. If omitted, the
            figure is left open for the caller to show or save.
    """
    set_chinese_font()

    # Accept either a DatasetDict (use its train split) or a single Dataset.
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']
    else:
        ds = dataset_dict

    # Objects without ``features`` (e.g. a plain list of dicts) take the
    # per-record fallback instead of raising AttributeError.
    features = getattr(ds, 'features', None) or {}
    if 'label' in features:
        train_labels = ds['label']
    elif 'labels' in features:
        train_labels = ds['labels']
    else:
        train_labels = [x.get('label', x.get('labels')) for x in ds]

    # Map class ids back to readable names; unknown ids are shown as-is.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    labels_str = [id2label.get(x, str(x)) for x in train_labels]
    if not labels_str:
        # plt.pie cannot draw an empty chart.
        print("没有可用的标签数据,跳过分布图。")
        return

    counts = pd.Series(labels_str).value_counts()

    fig = plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140,
            colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        _ensure_parent_dir(save_path)
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def plot_training_history(log_history, save_path=None):
    """Plot loss and accuracy curves from a Hugging Face ``Trainer`` log history.

    Args:
        log_history: List of per-step log dicts, e.g. the ``log_history`` entry of
            ``trainer_state.json``. Each row is expected to have an ``'epoch'``
            value plus any of ``'loss'``, ``'eval_loss'`` and ``'eval_accuracy'``.
            Metrics that never appear are simply not plotted.
        save_path: Where to write the PNG. Defaults to
            ``<Config.RESULTS_DIR>/images/training_metrics_<timestamp>.png``.

    Side effects:
        Writes the figure to ``save_path``, creating its directory. When
        validation accuracy was logged, also writes a ``metrics_<timestamp>.txt``
        summary in the same directory as the plot. The figure is closed
        afterwards.
    """
    set_chinese_font()

    if not log_history:
        print("没有可用的训练日志。")
        return

    df = pd.DataFrame(log_history)
    # Training-step rows and evaluation rows share one list, so each metric is
    # present only on the rows where it was logged (NaN or absent elsewhere).
    train_loss = _rows_with_column(df, 'loss')
    eval_loss = _rows_with_column(df, 'eval_loss')
    eval_acc = _rows_with_column(df, 'eval_accuracy')

    fig = plt.figure(figsize=(14, 5))

    # 1. Loss curves
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if not eval_loss.empty:
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    if plt.gca().get_lines():  # avoid matplotlib's "no artists" legend warning
        plt.legend()
    plt.grid(True, alpha=0.3)

    # 2. Accuracy curve (only when an accuracy metric was logged)
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Validation Accuracy',
                 color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    plt.tight_layout()

    # Timestamp string, e.g. 2024-12-18_14-30-00
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if save_path is None:
        save_path = os.path.join(_default_results_image_dir(), f"training_metrics_{timestamp}.png")
    _ensure_parent_dir(save_path)

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)
    plt.close(fig)

    # Also save the final validation metrics as a small text summary next to the plot.
    if not eval_acc.empty:
        last_eval = eval_acc.iloc[-1]
        final_acc = last_eval['eval_accuracy']
        final_loss = last_eval['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(os.path.dirname(save_path), f"metrics_{timestamp}.txt")
        with open(metrics_file, "w", encoding="utf-8") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")


def load_and_plot_logs(log_dir):
    """Load ``trainer_state.json`` from ``log_dir`` and plot its ``log_history``.

    Prints a message and returns if the file does not exist. A file without a
    ``log_history`` entry is treated as having no training logs.
    """
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return

    with open(json_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    plot_training_history(data.get('log_history'))


if __name__ == "__main__":
    import glob
    import sys

    # When this script is run directly, put the parent of this file's directory
    # (the project root) on sys.path so the absolute ``src.*`` imports resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config

    # ---------------------------------------------------------
    # 1. Data distribution plot
    # ---------------------------------------------------------
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        # Presumably reuses a processed copy cached under Config.DATA_DIR (fast path).
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        # Timestamped file name so successive runs do not overwrite each other.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images", f"data_distribution_{timestamp}.png")

        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")
    except Exception as e:
        # Top-level script boundary: report the failure and continue to step 2.
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # ---------------------------------------------------------
    # 2. Training curves (training history)
    # ---------------------------------------------------------
    search_patterns = [
        Config.OUTPUT_DIR,
        os.path.join(Config.OUTPUT_DIR, "checkpoint-*"),
        os.path.join(Config.RESULTS_DIR, "checkpoint-*"),
    ]
    # Only directories that actually contain a trainer log are usable. Otherwise
    # a newer but log-less directory would hide an older valid checkpoint.
    candidates = [
        path
        for pattern in search_patterns
        for path in glob.glob(pattern)
        if os.path.isfile(os.path.join(path, "trainer_state.json"))
    ]

    if candidates:
        # Pick the directory whose trainer log was written most recently.
        latest_ckpt = max(
            candidates,
            key=lambda path: os.path.getmtime(os.path.join(path, "trainer_state.json")),
        )
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")