stmasson commited on
Commit
16f173b
·
verified ·
1 Parent(s): 5d0796b

Upload scripts/train_alizee_v2_stage1_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_alizee_v2_stage1_sft.py +330 -0
scripts/train_alizee_v2_stage1_sft.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "trl>=0.17.0",
5
+ # "peft>=0.14.0",
6
+ # "transformers>=4.48.0",
7
+ # "accelerate>=0.35.0",
8
+ # "bitsandbytes>=0.45.0",
9
+ # "trackio",
10
+ # "datasets>=3.0.0",
11
+ # "flash-attn>=2.5.0",
12
+ # ]
13
+ # ///
14
+
15
+ """
16
+ Stage 1: Reasoning Distillation via SFT for Alizee-Coder-Devstral-2-Small
17
+
18
+ Training stmasson/alizee-coder-devstral-1-small on nvidia/OpenCodeReasoning (736K samples)
19
+ with 85% reasoning traces + 15% coding capability preservation from bigcode/starcoderdata.
20
+
21
+ Key features:
22
+ - QLoRA (r=64, alpha=128) for memory-efficient training
23
+ - 32K context window support
24
+ - Gradient checkpointing + Flash Attention 2
25
+ - Automatic data mixing and formatting
26
+ - Trackio monitoring
27
+
28
+ Based on NVIDIA's research: performance improves linearly 25K->736K samples
29
+ """
30
+
31
+ import os
32
+ import random
33
+ import trackio
34
+ from datasets import load_dataset, concatenate_datasets, Dataset
35
+ from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
36
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
37
+ from trl import SFTTrainer, SFTConfig
38
+
39
+ # Configuration
40
+ MODEL_NAME = "stmasson/alizee-coder-devstral-1-small"
41
+ OUTPUT_REPO = "stmasson/alizee-coder-devstral-2-small-stage1"
42
+ FINAL_REPO = "stmasson/alizee-coder-devstral-2-small"
43
+
44
+ # Training hyperparameters (from user spec)
45
+ LEARNING_RATE = 5e-5
46
+ EFFECTIVE_BATCH_SIZE = 256
47
+ PER_DEVICE_BATCH = 1
48
+ GRADIENT_ACCUMULATION = EFFECTIVE_BATCH_SIZE // PER_DEVICE_BATCH
49
+ MAX_SEQ_LENGTH = 32768
50
+ NUM_EPOCHS = 2
51
+ WARMUP_RATIO = 0.05
52
+
53
+ # Data mixing ratio
54
+ REASONING_RATIO = 0.85
55
+ CODING_RATIO = 0.15
56
+
57
+ print("=" * 60)
58
+ print("Stage 1: Reasoning Distillation via SFT")
59
+ print("=" * 60)
60
+ print(f"Base model: {MODEL_NAME}")
61
+ print(f"Output: {OUTPUT_REPO}")
62
+ print(f"Data mix: {REASONING_RATIO*100}% reasoning + {CODING_RATIO*100}% coding")
63
+ print("=" * 60)
64
+
65
+ # Load tokenizer
66
+ print("\n📝 Loading tokenizer...")
67
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
68
+ if tokenizer.pad_token is None:
69
+ tokenizer.pad_token = tokenizer.eos_token
70
+ tokenizer.padding_side = "right"
71
+
72
+ # QLoRA quantization config
73
+ print("\n⚙️ Configuring 4-bit quantization...")
74
+ bnb_config = BitsAndBytesConfig(
75
+ load_in_4bit=True,
76
+ bnb_4bit_quant_type="nf4",
77
+ bnb_4bit_compute_dtype="bfloat16",
78
+ bnb_4bit_use_double_quant=True,
79
+ )
80
+
81
+ # Load model
82
+ print("\n🔄 Loading model with QLoRA...")
83
+ model = AutoModelForCausalLM.from_pretrained(
84
+ MODEL_NAME,
85
+ quantization_config=bnb_config,
86
+ device_map="auto",
87
+ trust_remote_code=True,
88
+ attn_implementation="flash_attention_2",
89
+ torch_dtype="auto",
90
+ )
91
+ model = prepare_model_for_kbit_training(model)
92
+
93
+ # LoRA configuration (r=64, alpha=128 as specified)
94
+ print("\n🎯 Configuring LoRA adapters...")
95
+ lora_config = LoraConfig(
96
+ r=64,
97
+ lora_alpha=128,
98
+ lora_dropout=0.05,
99
+ bias="none",
100
+ task_type="CAUSAL_LM",
101
+ target_modules=[
102
+ "q_proj", "k_proj", "v_proj", "o_proj",
103
+ "gate_proj", "up_proj", "down_proj"
104
+ ],
105
+ )
106
+
107
+ model = get_peft_model(model, lora_config)
108
+ model.print_trainable_parameters()
109
+
110
+ # Load and prepare datasets
111
+ print("\n📦 Loading datasets...")
112
+
113
+ # 1. OpenCodeReasoning (reasoning traces)
114
+ print(" Loading nvidia/OpenCodeReasoning split_0...")
115
+ ocr_split0 = load_dataset("nvidia/OpenCodeReasoning", "split_0", split="train")
116
+ print(f" -> split_0: {len(ocr_split0)} samples")
117
+
118
+ print(" Loading nvidia/OpenCodeReasoning split_1...")
119
+ ocr_split1 = load_dataset("nvidia/OpenCodeReasoning", "split_1", split="train")
120
+ print(f" -> split_1: {len(ocr_split1)} samples")
121
+
122
+ # Combine OpenCodeReasoning splits
123
+ ocr_full = concatenate_datasets([ocr_split0, ocr_split1])
124
+ print(f" Total OpenCodeReasoning: {len(ocr_full)} samples")
125
+
126
+ # 2. Coding capability preservation dataset
127
+ print(" Loading bigcode/starcoderdata (python subset)...")
128
+ # Load a subset of starcoderdata for coding preservation
129
+ coding_ds = load_dataset(
130
+ "bigcode/starcoderdata",
131
+ data_dir="python",
132
+ split="train",
133
+ streaming=True
134
+ )
135
+
136
+ # Calculate how many coding samples we need (15% of total)
137
+ total_reasoning = len(ocr_full)
138
+ num_coding_samples = int(total_reasoning * CODING_RATIO / REASONING_RATIO)
139
+ print(f" Need {num_coding_samples} coding samples for 15% mix")
140
+
141
+ # Take samples from streaming dataset
142
+ print(" Sampling coding data...")
143
+ coding_samples = []
144
+ for i, sample in enumerate(coding_ds):
145
+ if i >= num_coding_samples:
146
+ break
147
+ coding_samples.append(sample)
148
+ if i % 50000 == 0 and i > 0:
149
+ print(f" Collected {i} coding samples...")
150
+
151
+ coding_ds_final = Dataset.from_list(coding_samples)
152
+ print(f" Collected {len(coding_ds_final)} coding samples")
153
+
154
+ # Format functions for different data sources
155
+ def format_reasoning_sample(example):
156
+ """Format OpenCodeReasoning sample for instruction tuning.
157
+
158
+ OpenCodeReasoning has:
159
+ - input: problem description
160
+ - output: reasoning trace / expected output explanation
161
+ - solution: the actual code
162
+ """
163
+ # Create a reasoning-enhanced prompt
164
+ messages = [
165
+ {
166
+ "role": "user",
167
+ "content": f"Solve the following programming problem. Think through it step by step.\n\n{example['input']}"
168
+ },
169
+ {
170
+ "role": "assistant",
171
+ "content": f"Let me think through this problem step by step.\n\n{example['output']}\n\nHere's my solution:\n\n```python\n{example['solution']}\n```"
172
+ }
173
+ ]
174
+
175
+ return {"messages": messages, "source": "reasoning"}
176
+
177
+ def format_coding_sample(example):
178
+ """Format starcoderdata sample for capability preservation."""
179
+ # Extract code content
180
+ content = example.get("content", "")
181
+
182
+ # Create a simple code completion task
183
+ lines = content.split("\n")
184
+ if len(lines) > 10:
185
+ # Split into prompt and completion
186
+ split_point = len(lines) // 3
187
+ prompt_code = "\n".join(lines[:split_point])
188
+ completion_code = "\n".join(lines[split_point:])
189
+
190
+ messages = [
191
+ {
192
+ "role": "user",
193
+ "content": f"Complete the following Python code:\n\n```python\n{prompt_code}\n```"
194
+ },
195
+ {
196
+ "role": "assistant",
197
+ "content": f"```python\n{completion_code}\n```"
198
+ }
199
+ ]
200
+ else:
201
+ # For short snippets, ask to explain and reproduce
202
+ messages = [
203
+ {
204
+ "role": "user",
205
+ "content": f"Write Python code that implements the following:\n\n{content[:200]}..."
206
+ },
207
+ {
208
+ "role": "assistant",
209
+ "content": f"```python\n{content}\n```"
210
+ }
211
+ ]
212
+
213
+ return {"messages": messages, "source": "coding"}
214
+
215
+ # Format datasets
216
+ print("\n🔄 Formatting datasets...")
217
+ print(" Formatting reasoning samples...")
218
+ reasoning_formatted = ocr_full.map(
219
+ format_reasoning_sample,
220
+ remove_columns=ocr_full.column_names,
221
+ num_proc=8,
222
+ desc="Formatting reasoning"
223
+ )
224
+
225
+ print(" Formatting coding samples...")
226
+ coding_formatted = coding_ds_final.map(
227
+ format_coding_sample,
228
+ remove_columns=coding_ds_final.column_names,
229
+ num_proc=4,
230
+ desc="Formatting coding"
231
+ )
232
+
233
+ # Combine and shuffle
234
+ print("\n🔀 Combining and shuffling datasets...")
235
+ combined_dataset = concatenate_datasets([reasoning_formatted, coding_formatted])
236
+ combined_dataset = combined_dataset.shuffle(seed=42)
237
+ print(f" Total training samples: {len(combined_dataset)}")
238
+
239
+ # Create train/eval split
240
+ print(" Creating train/eval split...")
241
+ split_dataset = combined_dataset.train_test_split(test_size=0.005, seed=42)
242
+ train_dataset = split_dataset["train"]
243
+ eval_dataset = split_dataset["test"]
244
+ print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
245
+
246
+ # Training configuration
247
+ print("\n⚙️ Configuring training...")
248
+ training_config = SFTConfig(
249
+ # Output and Hub settings
250
+ output_dir="alizee-v2-stage1-sft",
251
+ push_to_hub=True,
252
+ hub_model_id=OUTPUT_REPO,
253
+ hub_strategy="every_save",
254
+ hub_private_repo=False,
255
+
256
+ # Training parameters
257
+ num_train_epochs=NUM_EPOCHS,
258
+ per_device_train_batch_size=PER_DEVICE_BATCH,
259
+ per_device_eval_batch_size=PER_DEVICE_BATCH,
260
+ gradient_accumulation_steps=GRADIENT_ACCUMULATION,
261
+ learning_rate=LEARNING_RATE,
262
+ max_seq_length=MAX_SEQ_LENGTH,
263
+
264
+ # Optimization
265
+ warmup_ratio=WARMUP_RATIO,
266
+ lr_scheduler_type="cosine",
267
+ optim="adamw_8bit",
268
+ bf16=True,
269
+ gradient_checkpointing=True,
270
+ gradient_checkpointing_kwargs={"use_reentrant": False},
271
+
272
+ # Logging and checkpointing
273
+ logging_steps=10,
274
+ save_strategy="steps",
275
+ save_steps=500,
276
+ save_total_limit=3,
277
+ eval_strategy="steps",
278
+ eval_steps=500,
279
+
280
+ # Monitoring
281
+ report_to="trackio",
282
+ project="alizee-coder-v2",
283
+ run_name="stage1-reasoning-sft",
284
+
285
+ # Other settings
286
+ max_grad_norm=1.0,
287
+ dataloader_num_workers=4,
288
+ remove_unused_columns=True,
289
+ packing=False, # Disable packing for long sequences
290
+ )
291
+
292
+ # Initialize trainer
293
+ print("\n🎯 Initializing SFT Trainer...")
294
+ trainer = SFTTrainer(
295
+ model=model,
296
+ tokenizer=tokenizer,
297
+ train_dataset=train_dataset,
298
+ eval_dataset=eval_dataset,
299
+ args=training_config,
300
+ )
301
+
302
+ # Calculate and display training info
303
+ total_steps = (len(train_dataset) // EFFECTIVE_BATCH_SIZE) * NUM_EPOCHS
304
+ print(f"\n📊 Training Configuration Summary:")
305
+ print(f" Total samples: {len(train_dataset)}")
306
+ print(f" Effective batch size: {EFFECTIVE_BATCH_SIZE}")
307
+ print(f" Steps per epoch: {len(train_dataset) // EFFECTIVE_BATCH_SIZE}")
308
+ print(f" Total steps: {total_steps}")
309
+ print(f" Epochs: {NUM_EPOCHS}")
310
+
311
+ # Start training
312
+ print("\n🚀 Starting Stage 1 Reasoning SFT Training...")
313
+ print(" This will take 16-24+ hours on A100-80GB")
314
+ print(" Monitor at: https://huggingface.co/spaces/stmasson/trackio")
315
+ print("=" * 60)
316
+
317
+ trainer.train()
318
+
319
+ # Save final model
320
+ print("\n💾 Pushing final model to Hub...")
321
+ trainer.push_to_hub()
322
+
323
+ # Finish tracking
324
+ trackio.finish()
325
+
326
+ print("\n" + "=" * 60)
327
+ print("✅ Stage 1 Complete!")
328
+ print(f" Model saved to: https://huggingface.co/{OUTPUT_REPO}")
329
+ print(" Ready for Stage 2 (optional DPO refresh)")
330
+ print("=" * 60)