stmasson committed on
Commit e81fa0f · verified · 1 Parent(s): 16f173b

Upload scripts/train_alizee_v2_stage2_dpo.py with huggingface_hub

Files changed (1)
  1. scripts/train_alizee_v2_stage2_dpo.py +223 -0
scripts/train_alizee_v2_stage2_dpo.py ADDED
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.17.0",
#     "peft>=0.14.0",
#     "transformers>=4.48.0",
#     "accelerate>=0.35.0",
#     "bitsandbytes>=0.45.0",
#     "trackio",
#     "datasets>=3.0.0",
#     "flash-attn>=2.5.0",
# ]
# ///

"""
Stage 2: Light DPO Refresh for Alizee-Coder-Devstral-2-Small

Conservative DPO (beta=0.1, lr=5e-6) using CodeUltraFeedback to restore alignment
after reasoning SFT. This stage is OPTIONAL - run only if evaluation shows
alignment degradation.

Key settings (from user spec):
- beta=0.1 (conservative KL penalty)
- learning_rate=5e-6 (very low to preserve Stage 1 gains)
"""

import os
import trackio
from datasets import load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOTrainer, DPOConfig

# Configuration
MODEL_NAME = "stmasson/alizee-coder-devstral-2-small-stage1"  # Output from Stage 1
OUTPUT_REPO = "stmasson/alizee-coder-devstral-2-small-stage2"

# DPO hyperparameters (conservative as specified)
BETA = 0.1  # KL penalty - higher = stay closer to reference
LEARNING_RATE = 5e-6  # Very low LR for alignment refresh
EFFECTIVE_BATCH_SIZE = 64
PER_DEVICE_BATCH = 1
GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
MAX_SEQ_LENGTH = 8192  # Shorter context for DPO
NUM_EPOCHS = 1
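# With PER_DEVICE_BATCH = 1, the accumulation above works out to 64 // 1 = 64
# micro-batches per optimizer step, i.e. an effective batch of 64 preference pairs.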

print("=" * 60)
print("Stage 2: Light DPO Refresh (Optional)")
print("=" * 60)
print(f"Base model: {MODEL_NAME}")
print(f"Output: {OUTPUT_REPO}")
print(f"Beta (KL penalty): {BETA}")
print(f"Learning rate: {LEARNING_RATE}")
print("=" * 60)

# Load tokenizer
print("\n📝 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # DPO prefers left padding

# QLoRA quantization config
print("\n⚙️ Configuring 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_use_double_quant=True,
)

# Load model (which already has merged LoRA from Stage 1)
print("\n🔄 Loading Stage 1 model with QLoRA...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
)
model = prepare_model_for_kbit_training(model)
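# Note: prepare_model_for_kbit_training freezes the quantized base weights and,
# in recent peft versions, upcasts norm/embedding-adjacent layers to fp32 and
# enables input gradients so gradient checkpointing works with k-bit models.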

# LoRA configuration (smaller for DPO - just alignment refresh)
print("\n🎯 Configuring LoRA adapters for DPO...")
lora_config = LoraConfig(
    r=32,  # Smaller rank for DPO refresh
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
)
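# With r=32 and lora_alpha=64, adapter updates are scaled by lora_alpha / r = 2.0
# (peft's default scaling when rank-stabilized LoRA is not enabled).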

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Load CodeUltraFeedback dataset
print("\n📦 Loading CodeUltraFeedback dataset...")
dataset = load_dataset("RLHFlow/CodeUltraFeedback-standard", split="train")
print(f" Loaded {len(dataset)} preference pairs")

def format_for_dpo(example):
    """Format CodeUltraFeedback for DPO training.

    CodeUltraFeedback-standard has:
    - prompt: the coding instruction
    - chosen: the better response
    - rejected: the worse response
    """
    return {
        "prompt": example["prompt"],
        "chosen": example["chosen"],
        "rejected": example["rejected"],
    }

# Format dataset
print("\n🔄 Formatting dataset for DPO...")
formatted_dataset = dataset.map(
    format_for_dpo,
    remove_columns=[col for col in dataset.column_names if col not in ["prompt", "chosen", "rejected"]],
    num_proc=4,
)
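# The mapped dataset keeps only the three columns TRL's DPOTrainer expects for an
# explicit preference dataset ("prompt", "chosen", "rejected"); the trainer
# tokenizes these internally, so no manual tokenization is needed here.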

# Create train/eval split
print(" Creating train/eval split...")
split_dataset = formatted_dataset.train_test_split(test_size=0.05, seed=42)
train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

# DPO Training configuration
print("\n⚙️ Configuring DPO training...")
training_config = DPOConfig(
    # Output and Hub settings
    output_dir="alizee-v2-stage2-dpo",
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_strategy="every_save",
    hub_private_repo=False,

    # DPO-specific
    beta=BETA,

    # Training parameters
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION,
    learning_rate=LEARNING_RATE,
    max_length=MAX_SEQ_LENGTH,
    max_prompt_length=MAX_SEQ_LENGTH // 2,

    # Optimization
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},

    # Logging and checkpointing
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=200,

    # Monitoring
    report_to="trackio",
    project="alizee-coder-v2",
    run_name="stage2-dpo-refresh",

    # Other settings
    max_grad_norm=1.0,
    remove_unused_columns=False,
)

# Initialize trainer
print("\n🎯 Initializing DPO Trainer...")
trainer = DPOTrainer(
    model=model,  # already wrapped with the LoRA adapter above
    ref_model=None,  # Use implicit reference (adapter disabled for reference log-probs)
    processing_class=tokenizer,  # recent TRL versions take the tokenizer via `processing_class`
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_config,
    # peft_config is omitted here because the model is already a PeftModel
)
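# With a PEFT model and ref_model=None, TRL computes the reference log-probs by
# temporarily disabling the LoRA adapter, so no separate frozen copy of the
# quantized base model needs to be held in memory.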

# Calculate and display training info
total_steps = (len(train_dataset) // EFFECTIVE_BATCH_SIZE) * NUM_EPOCHS
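# Illustrative only: with e.g. 50,000 training pairs this gives 50000 // 64 * 1 = 781
# optimizer steps; the actual count depends on the CodeUltraFeedback split size.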
print(f"\n📊 DPO Training Configuration Summary:")
print(f" Total preference pairs: {len(train_dataset)}")
print(f" Effective batch size: {EFFECTIVE_BATCH_SIZE}")
print(f" Total steps: {total_steps}")
print(f" Beta (KL penalty): {BETA}")
print(f" Learning rate: {LEARNING_RATE}")

# Start training
print("\n🚀 Starting Stage 2 DPO Refresh...")
print(" This should take 2-4 hours on A100-80GB")
print(" Monitor at: https://huggingface.co/spaces/stmasson/trackio")
print("=" * 60)

trainer.train()

# Save final model
print("\n💾 Pushing Stage 2 model to Hub...")
trainer.push_to_hub()

# Finish tracking
trackio.finish()

print("\n" + "=" * 60)
print("✅ Stage 2 Complete!")
print(f" Model saved to: https://huggingface.co/{OUTPUT_REPO}")
print(" Ready for Stage 3 (adapter merging)")
print("=" * 60)