import json
import os
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def set_chinese_font():
    """Configure Matplotlib to render Chinese text in plot titles and labels."""
    # Font candidates cover macOS (Arial Unicode MS, PingFang SC, Heiti TC)
    # and Windows (SimHei); Matplotlib falls back through the list in order.
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    plt.rcParams['axes.unicode_minus'] = False
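
# A quick way to check which of those fonts are actually installed (uses the
# public matplotlib.font_manager API; the output will vary by machine):
#
#     from matplotlib import font_manager
#     installed = {f.name for f in font_manager.fontManager.ttflist}
#     print([n for n in plt.rcParams['font.sans-serif'] if n in installed])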


def plot_data_distribution(dataset_dict, save_path=None):
    """Plot a pie chart of the Positive/Neutral/Negative label distribution."""
    set_chinese_font()

    # Accept either a DatasetDict (use its train split) or a single Dataset.
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']
    else:
        ds = dataset_dict

    # The label column may be called 'label' or 'labels' depending on preprocessing.
    if 'label' in ds.features:
        train_labels = ds['label']
    elif 'labels' in ds.features:
        train_labels = ds['labels']
    else:
        # Fall back to a per-example lookup.
        train_labels = [x.get('label', x.get('labels')) for x in ds]

    # Map class ids to readable names; unknown ids are shown verbatim.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    labels_str = [id2label.get(x, str(x)) for x in train_labels]

    df = pd.DataFrame({'Label': labels_str})
    counts = df['Label'].value_counts()

    plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140,
            colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        # Create the target directory if needed, then save.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
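
# Usage sketch (the dataset id is a placeholder; any datasets.Dataset or
# DatasetDict whose examples carry a 'label'/'labels' column works):
#
#     from datasets import load_dataset
#     ds = load_dataset("your/dataset-id")  # hypothetical dataset
#     plot_data_distribution(ds, save_path="label_distribution.png")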


def plot_training_history(log_history, save_path=None):
    """Plot loss and accuracy curves from a Trainer's log_history."""
    # Imported lazily so this module can be loaded even before the project
    # root has been added to sys.path (see the __main__ block below).
    from src.config import Config

    set_chinese_font()

    if not log_history:
        print("No training logs available.")
        return

    df = pd.DataFrame(log_history)

    # Trainer interleaves training rows (with 'loss') and evaluation rows
    # (with 'eval_*'); keep only the rows that carry each metric.
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else pd.DataFrame()
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else pd.DataFrame()

    plt.figure(figsize=(14, 5))

    # Left panel: training loss and, if logged, validation loss.
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # Right panel: validation accuracy, when available.
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'],
                 label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    save_dir = os.path.join(Config.RESULTS_DIR, "images")
    os.makedirs(save_dir, exist_ok=True)

    plt.tight_layout()

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if save_path is None:
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)

    # Record the final validation metrics alongside the plot.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
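
# Shape of log_history, for reference: transformers.Trainer interleaves
# training entries and evaluation entries. A synthetic example (numbers are
# made up) that exercises both panels:
#
#     fake_log = [
#         {"epoch": 1.0, "loss": 0.92},
#         {"epoch": 1.0, "eval_loss": 0.71, "eval_accuracy": 0.68},
#         {"epoch": 2.0, "loss": 0.55},
#         {"epoch": 2.0, "eval_loss": 0.52, "eval_accuracy": 0.80},
#     ]
#     plot_training_history(fake_log, save_path="demo_metrics.png")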


def load_and_plot_logs(log_dir):
    """Load trainer_state.json from a checkpoint directory and plot its curves."""
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"Log file not found: {json_path}")
        return

    with open(json_path, 'r') as f:
        data = json.load(f)

    plot_training_history(data['log_history'])
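
# Usage sketch (the path is hypothetical; any Hugging Face Trainer checkpoint
# directory containing trainer_state.json works):
#
#     load_and_plot_logs("results/checkpoint-500")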


if __name__ == "__main__":
    import sys

    # Make the project root importable so the `src.*` modules resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config

    # 1) Label-distribution plot (skipped gracefully if the data is missing).
    try:
        print("\nLoading the dataset to generate the label-distribution analysis...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images",
                                      f"data_distribution_{timestamp}.png")

        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"Label-distribution plot saved to: {dist_save_path}")
    except Exception as e:
        print(f"Could not generate the distribution plot "
              f"(the data may not have been downloaded or processed yet): {e}")

    # 2) Training curves from the most recently modified checkpoint.
    import glob

    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*"),
    ]

    candidates = []
    for p in search_paths:
        candidates.extend(glob.glob(p))

    if candidates:
        # Newest match last after sorting by modification time.
        candidates.sort(key=os.path.getmtime)
        latest_ckpt = candidates[-1]
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("No checkpoint or trainer_state.json log file found.")