#!/usr/bin/env python3
"""
PPO RLHF training script - align the Teacher model with human preferences.
Input:  SFT Teacher model + human preference data
Output: RLHF-aligned Teacher model
"""
import os
import warnings
from typing import List

import numpy as np
import torch
import wandb
from datasets import Dataset, load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    logging,
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)


class RLHFConfig:
    """RLHF training configuration."""

    # Model paths
    teacher_model_path = "./merged_model"  # Teacher model from the earlier SFT stage
    reward_model_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # Reward model

    # PPO training parameters
    learning_rate = 1e-5
    mini_batch_size = 1
    batch_size = 8
    gradient_accumulation_steps = 8
    ppo_epochs = 4
    max_grad_norm = 1.0

    # PPO-specific parameters
    init_kl_coef = 0.02
    target_kl = 0.01
    adap_kl_ctrl = True
    clip_reward_value = 5.0
    cliprange = 0.2
    cliprange_value = 0.2
    gamma = 1.0
    lam = 0.95

    # Generation parameters
    max_new_tokens = 150
    temperature = 0.7
    top_p = 0.9
    do_sample = True

    # Training control
    total_episodes = 1000
    save_freq = 100
    eval_freq = 50
    output_dir = "./rlhf_teacher_model"

    # LoRA parameters (when using LoRA for RLHF)
    use_lora = True
    lora_r = 16
    lora_alpha = 32
    lora_dropout = 0.1
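

# Sanity check (assumption: TRL's PPO setup generally expects batch_size to be a
# multiple of mini_batch_size * gradient_accumulation_steps; here 1 * 8 = 8).
assert RLHFConfig.batch_size % (
    RLHFConfig.mini_batch_size * RLHFConfig.gradient_accumulation_steps
) == 0, "batch_size should be a multiple of mini_batch_size * gradient_accumulation_steps"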


class RewardModelWrapper:
    """Wrapper around the reward model."""

    def __init__(self, model_name: str, device: str = "cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

        # Make sure a pad token is set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_reward(self, prompts: List[str], responses: List[str]) -> List[float]:
        """Compute reward scores for prompt/response pairs."""
        inputs = []
        for prompt, response in zip(prompts, responses):
            # Format as a dialogue, matching the reward model's expected input
            text = f"Human: {prompt}\n\nAssistant: {response}"
            inputs.append(text)

        # Batched inference
        with torch.no_grad():
            encoded = self.tokenizer(
                inputs,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)
            outputs = self.model(**encoded)
            rewards = outputs.logits.squeeze(-1).cpu().tolist()

        return rewards
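

# Illustrative usage sketch for RewardModelWrapper (not called during training;
# the prompt/response strings below are made up):
def _demo_reward_scoring():
    rm = RewardModelWrapper(RLHFConfig.reward_model_name)
    scores = rm.get_reward(
        ["What is RLHF?"],
        ["RLHF fine-tunes a language model against a reward model trained on human preferences."],
    )
    # One float per (prompt, response) pair; higher means more preferred.
    print(scores)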


def load_preference_dataset():
    """Load PPO prompts from public preference datasets."""
    print("📥 Loading preference dataset...")

    # Multiple data sources can be mixed
    datasets_config = [
        {
            "name": "Anthropic/hh-rlhf",
            "split": "train",
            "weight": 0.7,
        },
        {
            "name": "OpenAssistant/oasst1",
            "split": "train",
            "weight": 0.3,
        },
    ]

    all_prompts = []
    for config in datasets_config:
        try:
            dataset = load_dataset(config["name"], split=config["split"])

            # Handle the different dataset formats
            if config["name"] == "Anthropic/hh-rlhf":
                prompts = extract_prompts_from_hh(dataset)
            else:
                prompts = extract_prompts_from_oasst(dataset)

            # Subsample according to the configured weight
            sample_size = int(len(prompts) * config["weight"])
            prompts = prompts[:sample_size]
            all_prompts.extend(prompts)
            print(f"✅ Loaded {len(prompts)} prompts from {config['name']}")
        except Exception as e:
            print(f"⚠️ Failed to load {config['name']}: {e}")

    # Build a Dataset object
    return Dataset.from_dict({"prompt": all_prompts})


def extract_prompts_from_hh(dataset):
    """Extract prompts from the HH-RLHF dataset."""
    prompts = []
    for item in dataset:
        # Parse the HH-RLHF dialogue format
        text = item.get("chosen", "")
        if "Human:" in text:
            prompt = text.split("Human:")[-1].split("Assistant:")[0].strip()
            if len(prompt) > 10:  # Filter out prompts that are too short
                prompts.append(prompt)
    return prompts


def extract_prompts_from_oasst(dataset):
    """Extract prompts from the OpenAssistant dataset."""
    prompts = []
    for item in dataset:
        if item.get("role") == "prompter":
            prompt = item.get("text", "").strip()
            if len(prompt) > 10:
                prompts.append(prompt)
    return prompts
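

# Minimal, made-up example of the HH-RLHF "chosen" field and what
# extract_prompts_from_hh keeps from it (only the last Human turn):
def _demo_extract_hh():
    fake_dataset = [{
        "chosen": "\n\nHuman: How do I bake sourdough bread at home?"
                  "\n\nAssistant: Start with an active starter and ...",
    }]
    # -> ["How do I bake sourdough bread at home?"]
    print(extract_prompts_from_hh(fake_dataset))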


def prepare_teacher_model(config: RLHFConfig):
    """Prepare the Teacher model for RLHF."""
    print("🤖 Preparing teacher model for RLHF...")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Optionally attach LoRA adapters for RLHF
    if config.use_lora:
        print("🔧 Adding LoRA for RLHF training...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Wrap the policy with a value head for PPO (the base weights are already fp16)
    model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

    # Frozen reference model for the KL penalty. PPOTrainer expects a value-head
    # wrapper here; when the policy is a PEFT model, passing ref_model=None and
    # letting TRL disable the adapters is a lower-memory alternative.
    ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    ref_model.eval()

    return model, ref_model, tokenizer


def create_ppo_trainer(model, ref_model, tokenizer, config: RLHFConfig):
    """Create the PPO trainer."""
    print("🏋️ Creating PPO trainer...")

    ppo_config = PPOConfig(
        model_name=config.teacher_model_path,
        learning_rate=config.learning_rate,
        mini_batch_size=config.mini_batch_size,
        batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        ppo_epochs=config.ppo_epochs,
        max_grad_norm=config.max_grad_norm,
        init_kl_coef=config.init_kl_coef,
        target_kl=config.target_kl,
        adap_kl_ctrl=config.adap_kl_ctrl,
        cliprange=config.cliprange,
        cliprange_value=config.cliprange_value,
        gamma=config.gamma,
        lam=config.lam,
        remove_unused_columns=False,
        log_with="wandb" if wandb.run else None,
    )
    # Note: reward clipping (config.clip_reward_value) is applied manually in the
    # training loop rather than through PPOConfig.

    trainer = PPOTrainer(
        config=ppo_config,
        model=model,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )
    return trainer


def format_prompt_for_generation(prompt: str) -> str:
    """Apply the prompt template used for generation."""
    return f"### Human: {prompt}\n### Assistant:"
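
# For example, format_prompt_for_generation("Write a slogan for a fitness app")
# returns "### Human: Write a slogan for a fitness app\n### Assistant:".
# This template is assumed to match the prompt format used during SFT.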


def run_ppo_training():
    """Main PPO training loop."""
    print("🚀 Starting PPO RLHF Training...")

    # Initialize wandb (drop dunder attributes of the config class)
    wandb.init(
        project="rlhf-teacher-training",
        config={k: v for k, v in vars(RLHFConfig).items() if not k.startswith("__")},
        name="ppo-teacher-rlhf",
    )

    config = RLHFConfig()

    # Prepare the models
    model, ref_model, tokenizer = prepare_teacher_model(config)

    # Create the PPO trainer
    ppo_trainer = create_ppo_trainer(model, ref_model, tokenizer, config)

    # Load the reward model
    reward_model = RewardModelWrapper(config.reward_model_name)

    # Load the prompt dataset
    dataset = load_preference_dataset()
    print(f"📊 Training on {len(dataset)} prompts")
    print(f"🎯 Target episodes: {config.total_episodes}")

    device = ppo_trainer.accelerator.device

    # Training loop
    for episode in range(config.total_episodes):
        # Randomly sample a batch of prompts
        batch_prompts = np.random.choice(
            dataset["prompt"],
            size=config.batch_size,
            replace=False,
        ).tolist()

        # Format the inputs
        formatted_prompts = [format_prompt_for_generation(p) for p in batch_prompts]

        # Tokenize prompts as 1-D tensors (PPOTrainer expects shape (seq_len,))
        prompt_tensors = []
        for prompt in formatted_prompts:
            prompt_tensor = tokenizer.encode(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=256,
            ).squeeze(0).to(device)
            prompt_tensors.append(prompt_tensor)

        # Generate one response per prompt
        response_tensors = []
        with torch.no_grad():
            for prompt_tensor in prompt_tensors:
                response = ppo_trainer.generate(
                    prompt_tensor,
                    max_new_tokens=config.max_new_tokens,
                    temperature=config.temperature,
                    top_p=config.top_p,
                    do_sample=config.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                )
                # Keep only the newly generated tokens
                response = response.squeeze(0)[prompt_tensor.shape[0]:]
                response_tensors.append(response)

        # Decode the responses
        responses = [
            tokenizer.decode(r, skip_special_tokens=True).strip()
            for r in response_tensors
        ]

        # Compute rewards, clipped to +/- config.clip_reward_value
        rewards = reward_model.get_reward(batch_prompts, responses)
        rewards = [
            torch.clamp(
                torch.tensor(r, dtype=torch.float),
                -config.clip_reward_value,
                config.clip_reward_value,
            )
            for r in rewards
        ]

        # PPO optimization step
        stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)

        # Log statistics; log_stats expects a batch dict with queries and responses
        ppo_trainer.log_stats(
            stats,
            {"query": batch_prompts, "response": responses},
            rewards,
        )

        # Print progress
        if episode % 10 == 0:
            mean_reward = np.mean([r.item() for r in rewards])
            print(f"📈 Episode {episode}: Mean Reward = {mean_reward:.4f}")

            # Log to wandb
            wandb.log({
                "episode": episode,
                "mean_reward": mean_reward,
                "kl_divergence": stats.get("objective/kl", 0),
                "policy_loss": stats.get("ppo/loss/policy", 0),
                "value_loss": stats.get("ppo/loss/value", 0),
            })

        # Evaluate the model
        if episode % config.eval_freq == 0 and episode > 0:
            evaluate_model(ppo_trainer.model, tokenizer, episode)

        # Save a checkpoint
        if episode % config.save_freq == 0 and episode > 0:
            save_checkpoint(ppo_trainer.model, tokenizer, config.output_dir, episode)

    # Save the final model
    print("💾 Saving final RLHF model...")
    ppo_trainer.model.save_pretrained(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)
    wandb.finish()
    print("✅ RLHF training completed!")


def evaluate_model(model, tokenizer, episode):
    """Run a small qualitative evaluation of the current policy."""
    print(f"🧪 Evaluating model at episode {episode}...")

    test_prompts = [
        "Create an advertisement for a revolutionary smartphone with AI capabilities",
        "Write marketing copy for an eco-friendly clothing brand",
        "Generate a slogan for a fitness app targeting busy professionals",
    ]

    model.eval()
    device = next(model.parameters()).device
    results = []
    for prompt in test_prompts:
        formatted_prompt = format_prompt_for_generation(prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )

        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text = response[len(formatted_prompt):].strip()

        results.append({
            "prompt": prompt,
            "response": generated_text,
        })

        print(f"🔍 Prompt: {prompt}")
        print(f"📝 Response: {generated_text}")
        print("-" * 80)

    model.train()
    return results


def save_checkpoint(model, tokenizer, output_dir, episode):
    """Save a training checkpoint."""
    checkpoint_dir = f"{output_dir}/checkpoint-{episode}"
    os.makedirs(checkpoint_dir, exist_ok=True)
    model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    print(f"💾 Checkpoint saved to {checkpoint_dir}")


def load_checkpoint_and_continue(checkpoint_path):
    """Resume training from a checkpoint (not implemented yet)."""
    print(f"📥 Loading checkpoint from {checkpoint_path}")
    # TODO: reload the policy weights and tokenizer from checkpoint_path
    # and re-enter the PPO loop.
    pass


if __name__ == "__main__":
    # Environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # Multi-GPU setup
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Check GPU resources
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"   GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        raise RuntimeError("❌ CUDA not available! RLHF requires GPU.")

    # Start training
    run_ppo_training()
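
# Typical invocation (the script filename below is a placeholder):
#   python ppo_rlhf_teacher.py
# Multi-GPU runs are usually launched via `accelerate launch`, which TRL's
# PPOTrainer builds on:
#   accelerate launch ppo_rlhf_teacher.py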