stmasson commited on
Commit
807886d
·
verified ·
1 Parent(s): ab64aa3

Upload scripts/train_n8n_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_n8n_sft.py +275 -0
scripts/train_n8n_sft.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "transformers>=4.45.0",
5
+ # "trl>=0.12.0",
6
+ # "peft>=0.13.0",
7
+ # "datasets>=3.0.0",
8
+ # "accelerate>=1.0.0",
9
+ # "bitsandbytes>=0.44.0",
10
+ # "wandb>=0.18.0",
11
+ # "huggingface_hub>=0.26.0",
12
+ # "torch>=2.4.0",
13
+ # "einops>=0.8.0",
14
+ # "sentencepiece>=0.2.0",
15
+ # ]
16
+ # [tool.uv]
17
+ # extra-index-url = ["https://download.pytorch.org/whl/cu124"]
18
+ # ///
19
+ """
20
+ Script d'entraînement SFT pour le modèle n8n Expert.
21
+
22
+ Usage sur HuggingFace Jobs:
23
+ hf jobs uv run \
24
+ --script train_n8n_sft.py \
25
+ --flavor h100x1 \
26
+ --name n8n-expert-sft \
27
+ --timeout 24h
28
+
29
+ Variables d'environnement requises:
30
+ - HF_TOKEN: Token HuggingFace avec accès en écriture
31
+ - WANDB_API_KEY: (optionnel) Pour le tracking W&B
32
+ """
33
+
34
+ import os
35
+ import json
36
+ import torch
37
+ from datasets import load_dataset
38
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
39
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
40
+ from trl import SFTTrainer, SFTConfig
41
+ from huggingface_hub import login
42
+
43
+ # ============================================================================
44
+ # CONFIGURATION
45
+ # ============================================================================
46
+
47
+ # Modèle de base
48
+ MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")
49
+
50
+ # Dataset
51
+ DATASET_REPO = "stmasson/n8n-agentic-multitask"
52
+ TRAIN_FILE = "data/multitask_large/train.jsonl"
53
+ VAL_FILE = "data/multitask_large/val.jsonl"
54
+
55
+ # Output
56
+ OUTPUT_DIR = "./n8n-expert-sft"
57
+ HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-sft")
58
+
59
+ # Hyperparamètres
60
+ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
61
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
62
+ GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
63
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
64
+ MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "8192"))
65
+
66
+ # LoRA
67
+ LORA_R = int(os.environ.get("LORA_R", "64"))
68
+ LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "128"))
69
+ LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
70
+
71
+ # Quantization (pour économiser la VRAM)
72
+ USE_4BIT = os.environ.get("USE_4BIT", "false").lower() == "true"
73
+
74
+ # ============================================================================
75
+ # AUTHENTIFICATION
76
+ # ============================================================================
77
+
78
+ print("=" * 60)
79
+ print("ENTRAÎNEMENT SFT - N8N EXPERT")
80
+ print("=" * 60)
81
+
82
+ hf_token = os.environ.get("HF_TOKEN")
83
+ if hf_token:
84
+ login(token=hf_token)
85
+ print("Authentifié sur HuggingFace")
86
+ else:
87
+ print("Warning: HF_TOKEN non défini, push désactivé")
88
+
89
+ wandb_key = os.environ.get("WANDB_API_KEY")
90
+ if wandb_key:
91
+ import wandb
92
+ wandb.login(key=wandb_key)
93
+ report_to = "wandb"
94
+ print("Tracking W&B activé")
95
+ else:
96
+ report_to = "none"
97
+ print("Tracking W&B désactivé")
98
+
99
+ # ============================================================================
100
+ # CHARGEMENT DU MODÈLE
101
+ # ============================================================================
102
+
103
+ print(f"\nChargement du modèle: {MODEL_NAME}")
104
+
105
+ # Configuration quantization si nécessaire
106
+ if USE_4BIT:
107
+ print("Mode 4-bit activé")
108
+ bnb_config = BitsAndBytesConfig(
109
+ load_in_4bit=True,
110
+ bnb_4bit_quant_type="nf4",
111
+ bnb_4bit_compute_dtype=torch.bfloat16,
112
+ bnb_4bit_use_double_quant=True,
113
+ )
114
+ model = AutoModelForCausalLM.from_pretrained(
115
+ MODEL_NAME,
116
+ quantization_config=bnb_config,
117
+ device_map="auto",
118
+ trust_remote_code=True,
119
+ )
120
+ model = prepare_model_for_kbit_training(model)
121
+ else:
122
+ print("Mode bfloat16")
123
+ model = AutoModelForCausalLM.from_pretrained(
124
+ MODEL_NAME,
125
+ torch_dtype=torch.bfloat16,
126
+ attn_implementation="flash_attention_2",
127
+ device_map="auto",
128
+ trust_remote_code=True,
129
+ )
130
+
131
+ # Tokenizer
132
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
133
+ if tokenizer.pad_token is None:
134
+ tokenizer.pad_token = tokenizer.eos_token
135
+ tokenizer.padding_side = "right"
136
+
137
+ print(f"Modèle chargé: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
138
+
139
+ # ============================================================================
140
+ # CONFIGURATION LORA
141
+ # ============================================================================
142
+
143
+ print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}")
144
+
145
+ lora_config = LoraConfig(
146
+ r=LORA_R,
147
+ lora_alpha=LORA_ALPHA,
148
+ target_modules=[
149
+ "q_proj", "k_proj", "v_proj", "o_proj",
150
+ "gate_proj", "up_proj", "down_proj"
151
+ ],
152
+ lora_dropout=LORA_DROPOUT,
153
+ bias="none",
154
+ task_type="CAUSAL_LM"
155
+ )
156
+
157
+ # ============================================================================
158
+ # CHARGEMENT DU DATASET
159
+ # ============================================================================
160
+
161
+ print(f"\nChargement du dataset: {DATASET_REPO}")
162
+
163
+ dataset = load_dataset(
164
+ DATASET_REPO,
165
+ data_files={
166
+ "train": TRAIN_FILE,
167
+ "validation": VAL_FILE
168
+ }
169
+ )
170
+
171
+ print(f"Train: {len(dataset['train'])} exemples")
172
+ print(f"Validation: {len(dataset['validation'])} exemples")
173
+
174
+ # Fonction de formatage
175
+ def format_example(example):
176
+ """Formate les messages en texte pour l'entraînement"""
177
+ messages = example["messages"]
178
+ text = tokenizer.apply_chat_template(
179
+ messages,
180
+ tokenize=False,
181
+ add_generation_prompt=False
182
+ )
183
+ return {"text": text}
184
+
185
+ # Appliquer le formatage
186
+ print("Formatage des données...")
187
+ dataset = dataset.map(format_example, remove_columns=dataset["train"].column_names)
188
+
189
+ # Afficher un exemple
190
+ print("\nExemple de données formatées:")
191
+ print(dataset["train"][0]["text"][:500] + "...")
192
+
193
+ # ============================================================================
194
+ # CONFIGURATION D'ENTRAÎNEMENT
195
+ # ============================================================================
196
+
197
+ print(f"\nConfiguration d'entraînement:")
198
+ print(f" - Epochs: {NUM_EPOCHS}")
199
+ print(f" - Batch size: {BATCH_SIZE}")
200
+ print(f" - Gradient accumulation: {GRAD_ACCUM}")
201
+ print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
202
+ print(f" - Learning rate: {LEARNING_RATE}")
203
+ print(f" - Max sequence length: {MAX_SEQ_LENGTH}")
204
+
205
+ training_args = SFTConfig(
206
+ output_dir=OUTPUT_DIR,
207
+ num_train_epochs=NUM_EPOCHS,
208
+ per_device_train_batch_size=BATCH_SIZE,
209
+ per_device_eval_batch_size=BATCH_SIZE,
210
+ gradient_accumulation_steps=GRAD_ACCUM,
211
+ learning_rate=LEARNING_RATE,
212
+ lr_scheduler_type="cosine",
213
+ warmup_ratio=0.1,
214
+ weight_decay=0.01,
215
+ bf16=True,
216
+ tf32=True,
217
+ logging_steps=10,
218
+ save_strategy="steps",
219
+ save_steps=500,
220
+ save_total_limit=3,
221
+ eval_strategy="steps",
222
+ eval_steps=500,
223
+ max_seq_length=MAX_SEQ_LENGTH,
224
+ packing=True,
225
+ gradient_checkpointing=True,
226
+ gradient_checkpointing_kwargs={"use_reentrant": False},
227
+ dataset_text_field="text",
228
+ report_to=report_to,
229
+ run_name="n8n-expert-sft",
230
+ hub_model_id=HF_REPO if hf_token else None,
231
+ push_to_hub=bool(hf_token),
232
+ hub_strategy="checkpoint",
233
+ )
234
+
235
+ # ============================================================================
236
+ # ENTRAÎNEMENT
237
+ # ============================================================================
238
+
239
+ print("\nInitialisation du trainer...")
240
+
241
+ trainer = SFTTrainer(
242
+ model=model,
243
+ args=training_args,
244
+ train_dataset=dataset["train"],
245
+ eval_dataset=dataset["validation"],
246
+ peft_config=lora_config,
247
+ tokenizer=tokenizer,
248
+ )
249
+
250
+ # Afficher les paramètres entraînables
251
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
252
+ total_params = sum(p.numel() for p in model.parameters())
253
+ print(f"\nParamètres entraînables: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
254
+
255
+ print("\n" + "=" * 60)
256
+ print("DÉMARRAGE DE L'ENTRAÎNEMENT")
257
+ print("=" * 60)
258
+
259
+ trainer.train()
260
+
261
+ # ============================================================================
262
+ # SAUVEGARDE
263
+ # ============================================================================
264
+
265
+ print("\nSauvegarde du modèle...")
266
+ trainer.save_model(f"{OUTPUT_DIR}/final")
267
+
268
+ if hf_token:
269
+ print(f"Push vers {HF_REPO}...")
270
+ trainer.push_to_hub()
271
+ print(f"Modèle disponible sur: https://huggingface.co/{HF_REPO}")
272
+
273
+ print("\n" + "=" * 60)
274
+ print("ENTRAÎNEMENT TERMINÉ")
275
+ print("=" * 60)