File size: 6,800 Bytes
af9853e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import os
from datetime import datetime
def set_chinese_font():
    """Configure matplotlib fonts so CJK (Chinese) labels render correctly.

    Installs a list of sans-serif candidates commonly available on
    macOS/Windows; matplotlib picks the first one present. Also disables
    the unicode minus glyph so negative axis ticks display properly.
    """
    cjk_candidates = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    plt.rcParams['font.sans-serif'] = cjk_candidates
    plt.rcParams['axes.unicode_minus'] = False
def plot_data_distribution(dataset_dict, save_path=None):
    """Plot a pie chart of the Positive/Neutral/Negative label distribution.

    Args:
        dataset_dict: a DatasetDict (its 'train' split is used) or a bare
            Dataset. Integer labels are read from the 'label' or 'labels'
            column.
        save_path: optional PNG output path. The parent directory is
            created if missing; when None, nothing is written.
    """
    set_chinese_font()
    # Accept either a DatasetDict or a single Dataset.
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']
    else:
        ds = dataset_dict
    # Pull the label column under whichever name it uses.
    if 'label' in ds.features:
        train_labels = ds['label']
    elif 'labels' in ds.features:
        train_labels = ds['labels']
    else:
        # Fallback: iterate row dicts (slower, but handles odd schemas).
        train_labels = [x.get('label', x.get('labels')) for x in ds]
    # Map integer ids back to display strings for the chart legend.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    labels_str = [id2label.get(x, str(x)) for x in train_labels]
    df = pd.DataFrame({'Label': labels_str})
    counts = df['Label'].value_counts()
    plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140,
            colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()
    if save_path:
        # Make sure the target directory exists before writing.
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
        # Release the figure so repeated calls don't leak memory.
        plt.close()
def plot_training_history(log_history, save_path=None):
    """Plot training-loss and validation-accuracy curves from Trainer logs.

    Also writes a small text file with the final validation metrics next
    to the saved figure.

    Args:
        log_history: list of log dicts, as produced by
            ``transformers.Trainer`` (the ``log_history`` field of
            ``trainer_state.json``).
        save_path: optional PNG output path; defaults to a timestamped
            file under ``Config.RESULTS_DIR/images``.
    """
    set_chinese_font()
    if not log_history:
        print("没有可用的训练日志。")
        return
    df = pd.DataFrame(log_history)
    # Rows carrying 'loss' are training steps; rows carrying
    # 'eval_accuracy' are evaluation steps. Either column may be missing
    # entirely (e.g. eval-only logs), so guard before indexing.
    empty = df.iloc[0:0]
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else empty
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else empty
    plt.figure(figsize=(14, 5))
    # 1. Loss curves (training, plus validation when logged)
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    # 2. Validation accuracy curve
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'],
                 label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)
    # Local import: Config is otherwise only imported inside the
    # __main__ block, so referencing it here would raise NameError when
    # this module is imported and called from elsewhere.
    from src.config import Config
    # Ensure the output directory exists.
    save_dir = os.path.join(Config.RESULTS_DIR, "images")
    os.makedirs(save_dir, exist_ok=True)
    plt.tight_layout()
    # Timestamp string, e.g. 2024-12-18_14-30-00, shared by plot + metrics.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if save_path is None:
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")
    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)
    # Release the figure so repeated calls don't accumulate memory.
    plt.close()
    # Persist the final validation metrics as a small text file.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
def load_and_plot_logs(log_dir):
    """Load ``trainer_state.json`` from a checkpoint directory and plot it.

    Args:
        log_dir: directory expected to contain ``trainer_state.json``.
            Prints a message and returns if the file is absent.
    """
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return
    # trainer_state.json is written as UTF-8; be explicit so this also
    # works on platforms with a non-UTF-8 default locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Guard against malformed state files missing the log history;
    # plot_training_history handles an empty list gracefully.
    plot_training_history(data.get('log_history', []))
if __name__ == "__main__":
    import sys
    import glob

    # When run directly as a script, add the project root (one level up)
    # to sys.path so the absolute 'src.*' imports below resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)
    from src.config import Config

    # ---------------------------------------------------------
    # 1. Data distribution plot
    # ---------------------------------------------------------
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor
        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        # Prefer the cached, pre-processed dataset (fast path).
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)
        # Timestamped filename so successive runs don't overwrite.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images",
                                      f"data_distribution_{timestamp}.png")
        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")
    except Exception as e:
        # Best-effort: the dataset may not be downloaded/processed yet.
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # ---------------------------------------------------------
    # 2. Training history curves
    # ---------------------------------------------------------
    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*"),
    ]
    candidates = []
    for pattern in search_paths:
        candidates.extend(glob.glob(pattern))
    # Only consider locations that actually contain a trainer_state.json,
    # so a newer but log-less directory doesn't shadow a valid checkpoint.
    candidates = [c for c in candidates
                  if os.path.exists(os.path.join(c, 'trainer_state.json'))]
    if candidates:
        # Most recently modified candidate wins.
        latest_ckpt = max(candidates, key=os.path.getmtime)
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")
|