stmasson commited on
Commit
471ffcc
·
verified ·
1 Parent(s): 80fe464

Upload scripts/train_n8n_sft.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_n8n_sft.py +53 -40
scripts/train_n8n_sft.py CHANGED
@@ -34,17 +34,17 @@ Variables d'environnement requises:
34
  import os
35
  import json
36
  import torch
37
- from datasets import load_dataset
38
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
39
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
40
  from trl import SFTTrainer, SFTConfig
41
- from huggingface_hub import login
42
 
43
  # ============================================================================
44
  # CONFIGURATION
45
  # ============================================================================
46
 
47
- # Modèle de base
48
  MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")
49
 
50
  # Dataset
@@ -56,7 +56,7 @@ VAL_FILE = "data/multitask_large/val.jsonl"
56
  OUTPUT_DIR = "./n8n-expert-sft"
57
  HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-sft")
58
 
59
- # Hyperparamètres
60
  NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
61
  BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
62
  GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
@@ -68,7 +68,7 @@ LORA_R = int(os.environ.get("LORA_R", "64"))
68
  LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "128"))
69
  LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
70
 
71
- # Quantization (pour économiser la VRAM)
72
  USE_4BIT = os.environ.get("USE_4BIT", "false").lower() == "true"
73
 
74
  # ============================================================================
@@ -76,29 +76,29 @@ USE_4BIT = os.environ.get("USE_4BIT", "false").lower() == "true"
76
  # ============================================================================
77
 
78
  print("=" * 60)
79
- print("ENTRAÎNEMENT SFT - N8N EXPERT")
80
  print("=" * 60)
81
 
82
  hf_token = os.environ.get("HF_TOKEN")
83
  if hf_token:
84
  login(token=hf_token)
85
- print("Authentifié sur HuggingFace")
86
  else:
87
- print("Warning: HF_TOKEN non défini, push désactivé")
88
 
89
- # Désactivé wandb pour éviter les conflits de dépendances
90
  report_to = "none"
91
- print("Tracking désactivé (pas de wandb)")
92
 
93
  # ============================================================================
94
- # CHARGEMENT DU MODÈLE
95
  # ============================================================================
96
 
97
- print(f"\nChargement du modèle: {MODEL_NAME}")
98
 
99
- # Configuration quantization si nécessaire
100
  if USE_4BIT:
101
- print("Mode 4-bit activé")
102
  bnb_config = BitsAndBytesConfig(
103
  load_in_4bit=True,
104
  bnb_4bit_quant_type="nf4",
@@ -128,7 +128,7 @@ if tokenizer.pad_token is None:
128
  tokenizer.pad_token = tokenizer.eos_token
129
  tokenizer.padding_side = "right"
130
 
131
- print(f"Modèle chargé: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
132
 
133
  # ============================================================================
134
  # CONFIGURATION LORA
@@ -149,30 +149,43 @@ lora_config = LoraConfig(
149
  )
150
 
151
  # ============================================================================
152
- # CHARGEMENT DU DATASET
153
  # ============================================================================
154
 
155
  print(f"\nChargement du dataset: {DATASET_REPO}")
156
 
157
- # Charger train et validation séparément pour éviter les problèmes de schéma
158
- # (les colonnes metadata.node_types peuvent différer entre les splits)
159
- train_dataset = load_dataset(
160
- DATASET_REPO,
161
- data_files={"train": TRAIN_FILE},
162
- split="train"
163
- )
164
- val_dataset = load_dataset(
165
- DATASET_REPO,
166
- data_files={"train": VAL_FILE},
167
- split="train"
168
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  print(f"Train: {len(train_dataset)} exemples")
171
  print(f"Validation: {len(val_dataset)} exemples")
172
 
173
  # Fonction de formatage
174
  def format_example(example):
175
- """Formate les messages en texte pour l'entraînement"""
176
  messages = example["messages"]
177
  text = tokenizer.apply_chat_template(
178
  messages,
@@ -182,19 +195,19 @@ def format_example(example):
182
  return {"text": text}
183
 
184
  # Appliquer le formatage
185
- print("Formatage des données...")
186
  train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
187
  val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)
188
 
189
  # Afficher un exemple
190
- print("\nExemple de données formatées:")
191
  print(train_dataset[0]["text"][:500] + "...")
192
 
193
  # ============================================================================
194
- # CONFIGURATION D'ENTRAÎNEMENT
195
  # ============================================================================
196
 
197
- print(f"\nConfiguration d'entraînement:")
198
  print(f" - Epochs: {NUM_EPOCHS}")
199
  print(f" - Batch size: {BATCH_SIZE}")
200
  print(f" - Gradient accumulation: {GRAD_ACCUM}")
@@ -233,7 +246,7 @@ training_args = SFTConfig(
233
  )
234
 
235
  # ============================================================================
236
- # ENTRAÎNEMENT
237
  # ============================================================================
238
 
239
  print("\nInitialisation du trainer...")
@@ -247,13 +260,13 @@ trainer = SFTTrainer(
247
  tokenizer=tokenizer,
248
  )
249
 
250
- # Afficher les paramètres entraînables
251
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
252
  total_params = sum(p.numel() for p in model.parameters())
253
- print(f"\nParamètres entraînables: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
254
 
255
  print("\n" + "=" * 60)
256
- print("DÉMARRAGE DE L'ENTRAÎNEMENT")
257
  print("=" * 60)
258
 
259
  trainer.train()
@@ -262,14 +275,14 @@ trainer.train()
262
  # SAUVEGARDE
263
  # ============================================================================
264
 
265
- print("\nSauvegarde du modèle...")
266
  trainer.save_model(f"{OUTPUT_DIR}/final")
267
 
268
  if hf_token:
269
  print(f"Push vers {HF_REPO}...")
270
  trainer.push_to_hub()
271
- print(f"Modèle disponible sur: https://huggingface.co/{HF_REPO}")
272
 
273
  print("\n" + "=" * 60)
274
- print("ENTRAÎNEMENT TERMINÉ")
275
  print("=" * 60)
 
34
  import os
35
  import json
36
  import torch
37
+ from datasets import Dataset
38
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
39
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
40
  from trl import SFTTrainer, SFTConfig
41
+ from huggingface_hub import login, hf_hub_download
42
 
43
  # ============================================================================
44
  # CONFIGURATION
45
  # ============================================================================
46
 
47
+ # Modele de base
48
  MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")
49
 
50
  # Dataset
 
56
  OUTPUT_DIR = "./n8n-expert-sft"
57
  HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-sft")
58
 
59
+ # Hyperparametres
60
  NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
61
  BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
62
  GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
 
68
  LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "128"))
69
  LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
70
 
71
+ # Quantization (pour economiser la VRAM)
72
  USE_4BIT = os.environ.get("USE_4BIT", "false").lower() == "true"
73
 
74
  # ============================================================================
 
76
  # ============================================================================
77
 
78
  print("=" * 60)
79
+ print("ENTRAINEMENT SFT - N8N EXPERT")
80
  print("=" * 60)
81
 
82
  hf_token = os.environ.get("HF_TOKEN")
83
  if hf_token:
84
  login(token=hf_token)
85
+ print("Authentifie sur HuggingFace")
86
  else:
87
+ print("Warning: HF_TOKEN non defini, push desactive")
88
 
89
+ # Desactive wandb pour eviter les conflits de dependances
90
  report_to = "none"
91
+ print("Tracking desactive (pas de wandb)")
92
 
93
  # ============================================================================
94
+ # CHARGEMENT DU MODELE
95
  # ============================================================================
96
 
97
+ print(f"\nChargement du modele: {MODEL_NAME}")
98
 
99
+ # Configuration quantization si necessaire
100
  if USE_4BIT:
101
+ print("Mode 4-bit active")
102
  bnb_config = BitsAndBytesConfig(
103
  load_in_4bit=True,
104
  bnb_4bit_quant_type="nf4",
 
128
  tokenizer.pad_token = tokenizer.eos_token
129
  tokenizer.padding_side = "right"
130
 
131
+ print(f"Modele charge: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
132
 
133
  # ============================================================================
134
  # CONFIGURATION LORA
 
149
  )
150
 
151
  # ============================================================================
152
+ # CHARGEMENT DU DATASET (FIX: chargement direct JSON pour eviter les conflits de schema)
153
  # ============================================================================
154
 
155
  print(f"\nChargement du dataset: {DATASET_REPO}")
156
 
157
+ def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
158
+ """
159
+ Charge un dataset JSONL directement en ne gardant que la colonne 'messages'.
160
+ Evite les problemes de schema avec les colonnes struct comme 'nodes_used'.
161
+ """
162
+ # Telecharger le fichier
163
+ local_path = hf_hub_download(
164
+ repo_id=repo_id,
165
+ filename=filename,
166
+ repo_type="dataset"
167
+ )
168
+
169
+ # Lire le JSONL et extraire uniquement 'messages'
170
+ messages_list = []
171
+ with open(local_path, 'r', encoding='utf-8') as f:
172
+ for line in f:
173
+ data = json.loads(line)
174
+ messages_list.append({"messages": data["messages"]})
175
+
176
+ # Creer le Dataset
177
+ return Dataset.from_list(messages_list)
178
+
179
+ # Charger train et validation
180
+ train_dataset = load_jsonl_dataset(DATASET_REPO, TRAIN_FILE)
181
+ val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)
182
 
183
  print(f"Train: {len(train_dataset)} exemples")
184
  print(f"Validation: {len(val_dataset)} exemples")
185
 
186
  # Fonction de formatage
187
  def format_example(example):
188
+ """Formate les messages en texte pour l'entrainement"""
189
  messages = example["messages"]
190
  text = tokenizer.apply_chat_template(
191
  messages,
 
195
  return {"text": text}
196
 
197
  # Appliquer le formatage
198
+ print("Formatage des donnees...")
199
  train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
200
  val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)
201
 
202
  # Afficher un exemple
203
+ print("\nExemple de donnees formatees:")
204
  print(train_dataset[0]["text"][:500] + "...")
205
 
206
  # ============================================================================
207
+ # CONFIGURATION D'ENTRAINEMENT
208
  # ============================================================================
209
 
210
+ print(f"\nConfiguration d'entrainement:")
211
  print(f" - Epochs: {NUM_EPOCHS}")
212
  print(f" - Batch size: {BATCH_SIZE}")
213
  print(f" - Gradient accumulation: {GRAD_ACCUM}")
 
246
  )
247
 
248
  # ============================================================================
249
+ # ENTRAINEMENT
250
  # ============================================================================
251
 
252
  print("\nInitialisation du trainer...")
 
260
  tokenizer=tokenizer,
261
  )
262
 
263
+ # Afficher les parametres entrainables
264
  trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
265
  total_params = sum(p.numel() for p in model.parameters())
266
+ print(f"\nParametres entrainables: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
267
 
268
  print("\n" + "=" * 60)
269
+ print("DEMARRAGE DE L'ENTRAINEMENT")
270
  print("=" * 60)
271
 
272
  trainer.train()
 
275
  # SAUVEGARDE
276
  # ============================================================================
277
 
278
+ print("\nSauvegarde du modele...")
279
  trainer.save_model(f"{OUTPUT_DIR}/final")
280
 
281
  if hf_token:
282
  print(f"Push vers {HF_REPO}...")
283
  trainer.push_to_hub()
284
+ print(f"Modele disponible sur: https://huggingface.co/{HF_REPO}")
285
 
286
  print("\n" + "=" * 60)
287
+ print("ENTRAINEMENT TERMINE")
288
  print("=" * 60)