stmasson commited on
Commit
70f98a4
·
verified ·
1 Parent(s): 807886d

Upload scripts/train_n8n_dpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_n8n_dpo.py +244 -0
scripts/train_n8n_dpo.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.13.0",
7
+ # "datasets>=3.0.0",
8
+ # "accelerate>=1.0.0",
9
+ # "bitsandbytes>=0.44.0",
10
+ # "wandb>=0.18.0",
11
+ # "huggingface_hub>=0.26.0",
12
+ # "torch>=2.4.0",
13
+ # "einops>=0.8.0",
14
+ # "sentencepiece>=0.2.0",
15
+ # ]
16
+ # [tool.uv]
17
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
18
+ # ///
19
+ """
20
+ Script d'entraînement DPO pour le modèle n8n Expert.
21
+ À exécuter APRÈS l'entraînement SFT.
22
+
23
+ Usage sur HuggingFace Jobs:
24
+ hf jobs uv run \
25
+ --script train_n8n_dpo.py \
26
+ --flavor h100x1 \
27
+ --name n8n-expert-dpo \
28
+ --timeout 12h \
29
+ --env BASE_MODEL=stmasson/n8n-expert-14b-sft
30
+
31
+ Variables d'environnement:
32
+ - HF_TOKEN: Token HuggingFace
33
+ - BASE_MODEL: Modèle SFT à utiliser comme base
34
+ - WANDB_API_KEY: (optionnel) Pour le tracking
35
+ """
36
+
37
+ import os
38
+ import torch
39
+ from datasets import load_dataset
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer
41
+ from peft import LoraConfig, PeftModel
42
+ from trl import DPOTrainer, DPOConfig
43
+ from huggingface_hub import login
44
+
45
+ # ============================================================================
46
+ # CONFIGURATION
47
+ # ============================================================================
48
+
49
+ # Modèle SFT fine-tuné
50
+ BASE_MODEL = os.environ.get("BASE_MODEL", "stmasson/n8n-expert-14b-sft")
51
+ ORIGINAL_MODEL = os.environ.get("ORIGINAL_MODEL", "Qwen/Qwen2.5-14B-Instruct")
52
+
53
+ # Dataset DPO
54
+ DATASET_REPO = "stmasson/n8n-workflows-thinking"
55
+ DPO_FILE = "n8n_dpo_train.jsonl"
56
+
57
+ # Output
58
+ OUTPUT_DIR = "./n8n-expert-dpo"
59
+ HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-dpo")
60
+
61
+ # Hyperparamètres DPO
62
+ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "2"))
63
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
64
+ GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "16"))
65
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "5e-6"))
66
+ BETA = float(os.environ.get("DPO_BETA", "0.1"))
67
+ MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "8192"))
68
+ MAX_PROMPT_LENGTH = int(os.environ.get("MAX_PROMPT_LENGTH", "2048"))
69
+
70
+ # LoRA (plus léger pour DPO)
71
+ LORA_R = int(os.environ.get("LORA_R", "32"))
72
+ LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64"))
73
+
74
+ # ============================================================================
75
+ # AUTHENTIFICATION
76
+ # ============================================================================
77
+
78
+ print("=" * 60)
79
+ print("ENTRAÎNEMENT DPO - N8N EXPERT")
80
+ print("=" * 60)
81
+
82
+ hf_token = os.environ.get("HF_TOKEN")
83
+ if hf_token:
84
+ login(token=hf_token)
85
+ print("Authentifié sur HuggingFace")
86
+
87
+ wandb_key = os.environ.get("WANDB_API_KEY")
88
+ if wandb_key:
89
+ import wandb
90
+ wandb.login(key=wandb_key)
91
+ report_to = "wandb"
92
+ else:
93
+ report_to = "none"
94
+
95
+ # ============================================================================
96
+ # CHARGEMENT DU MODÈLE
97
+ # ============================================================================
98
+
99
+ print(f"\nChargement du modèle SFT: {BASE_MODEL}")
100
+
101
+ # Charger le modèle de base
102
+ model = AutoModelForCausalLM.from_pretrained(
103
+ BASE_MODEL,
104
+ torch_dtype=torch.bfloat16,
105
+ attn_implementation="flash_attention_2",
106
+ device_map="auto",
107
+ trust_remote_code=True,
108
+ )
109
+
110
+ # Charger le modèle de référence (pour DPO)
111
+ ref_model = AutoModelForCausalLM.from_pretrained(
112
+ BASE_MODEL,
113
+ torch_dtype=torch.bfloat16,
114
+ attn_implementation="flash_attention_2",
115
+ device_map="auto",
116
+ trust_remote_code=True,
117
+ )
118
+
119
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
120
+ if tokenizer.pad_token is None:
121
+ tokenizer.pad_token = tokenizer.eos_token
122
+
123
+ print("Modèle chargé")
124
+
125
+ # ============================================================================
126
+ # CONFIGURATION LORA
127
+ # ============================================================================
128
+
129
+ print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}")
130
+
131
+ lora_config = LoraConfig(
132
+ r=LORA_R,
133
+ lora_alpha=LORA_ALPHA,
134
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
135
+ lora_dropout=0.05,
136
+ bias="none",
137
+ task_type="CAUSAL_LM"
138
+ )
139
+
140
+ # ============================================================================
141
+ # CHARGEMENT DU DATASET DPO
142
+ # ============================================================================
143
+
144
+ print(f"\nChargement du dataset DPO: {DATASET_REPO}")
145
+
146
+ dataset = load_dataset(
147
+ DATASET_REPO,
148
+ data_files={"train": DPO_FILE},
149
+ split="train"
150
+ )
151
+
152
+ print(f"Exemples DPO: {len(dataset)}")
153
+
154
+ # Fonction de formatage pour DPO
155
+ def format_dpo_example(example):
156
+ """
157
+ Format attendu par DPOTrainer:
158
+ - prompt: le prompt de l'utilisateur
159
+ - chosen: la bonne réponse
160
+ - rejected: la mauvaise réponse
161
+ """
162
+ return {
163
+ "prompt": example["prompt"],
164
+ "chosen": example["chosen"],
165
+ "rejected": example["rejected"],
166
+ }
167
+
168
+ # Le dataset devrait déjà être au bon format
169
+ print("\nExemple de données DPO:")
170
+ print(f"Prompt: {dataset[0]['prompt'][:200]}...")
171
+ print(f"Chosen: {dataset[0]['chosen'][:200]}...")
172
+ print(f"Rejected: {dataset[0]['rejected'][:200]}...")
173
+
174
+ # ============================================================================
175
+ # CONFIGURATION D'ENTRAÎNEMENT DPO
176
+ # ============================================================================
177
+
178
+ print(f"\nConfiguration DPO:")
179
+ print(f" - Beta: {BETA}")
180
+ print(f" - Epochs: {NUM_EPOCHS}")
181
+ print(f" - Batch size: {BATCH_SIZE}")
182
+ print(f" - Gradient accumulation: {GRAD_ACCUM}")
183
+ print(f" - Learning rate: {LEARNING_RATE}")
184
+
185
+ dpo_config = DPOConfig(
186
+ output_dir=OUTPUT_DIR,
187
+ num_train_epochs=NUM_EPOCHS,
188
+ per_device_train_batch_size=BATCH_SIZE,
189
+ gradient_accumulation_steps=GRAD_ACCUM,
190
+ learning_rate=LEARNING_RATE,
191
+ beta=BETA,
192
+ lr_scheduler_type="cosine",
193
+ warmup_ratio=0.1,
194
+ bf16=True,
195
+ logging_steps=10,
196
+ save_strategy="steps",
197
+ save_steps=200,
198
+ save_total_limit=3,
199
+ max_length=MAX_LENGTH,
200
+ max_prompt_length=MAX_PROMPT_LENGTH,
201
+ gradient_checkpointing=True,
202
+ gradient_checkpointing_kwargs={"use_reentrant": False},
203
+ report_to=report_to,
204
+ run_name="n8n-expert-dpo",
205
+ hub_model_id=HF_REPO if hf_token else None,
206
+ push_to_hub=bool(hf_token),
207
+ )
208
+
209
+ # ============================================================================
210
+ # ENTRAÎNEMENT DPO
211
+ # ============================================================================
212
+
213
+ print("\nInitialisation du DPO trainer...")
214
+
215
+ trainer = DPOTrainer(
216
+ model=model,
217
+ ref_model=ref_model,
218
+ args=dpo_config,
219
+ train_dataset=dataset,
220
+ peft_config=lora_config,
221
+ tokenizer=tokenizer,
222
+ )
223
+
224
+ print("\n" + "=" * 60)
225
+ print("DÉMARRAGE DE L'ENTRAÎNEMENT DPO")
226
+ print("=" * 60)
227
+
228
+ trainer.train()
229
+
230
+ # ============================================================================
231
+ # SAUVEGARDE
232
+ # ============================================================================
233
+
234
+ print("\nSauvegarde du modèle...")
235
+ trainer.save_model(f"{OUTPUT_DIR}/final")
236
+
237
+ if hf_token:
238
+ print(f"Push vers {HF_REPO}...")
239
+ trainer.push_to_hub()
240
+ print(f"Modèle disponible sur: https://huggingface.co/{HF_REPO}")
241
+
242
+ print("\n" + "=" * 60)
243
+ print("ENTRAÎNEMENT DPO TERMINÉ")
244
+ print("=" * 60)