# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "trl>=0.12.0",
# "peft>=0.13.0",
# "datasets>=3.0.0",
# "accelerate>=1.0.0",
# "bitsandbytes>=0.44.0",
# "huggingface_hub>=0.26.0",
# "torch>=2.4.0",
# "einops>=0.8.0",
# "sentencepiece>=0.2.0",
# ]
# [tool.uv]
# index-strategy = "unsafe-best-match"
# extra-index-url = ["https://download.pytorch.org/whl/cu124"]
# ///
"""
Script d'entraînement SFT pour le modèle n8n Expert.
Usage sur HuggingFace Jobs:
hf jobs uv run \
--script train_n8n_sft.py \
--flavor h100x1 \
--name n8n-expert-sft \
--timeout 24h
Variables d'environnement requises:
- HF_TOKEN: Token HuggingFace avec accès en écriture
- WANDB_API_KEY: (optionnel) Pour le tracking W&B
"""
import os
import json
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login, hf_hub_download
# ============================================================================
# CONFIGURATION
# ============================================================================
# Base model
MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")
# Dataset
DATASET_REPO = "stmasson/n8n-agentic-multitask"
TRAIN_FILE = "data/multitask_large/train.jsonl"
VAL_FILE = "data/multitask_large/val.jsonl"
# Output
OUTPUT_DIR = "./n8n-expert-sft"
HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-sft")
# Hyperparameters
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "8192"))
# LoRA
LORA_R = int(os.environ.get("LORA_R", "64"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "128"))
LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
# Quantization (to save VRAM) - 4-bit by default on H100
USE_4BIT = os.environ.get("USE_4BIT", "true").lower() == "true"
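# All of the settings above can be overridden through environment variables
# without editing the file; a hypothetical local dry run (illustrative values)
# could look like:
#   NUM_EPOCHS=1 LORA_R=16 MAX_SEQ_LENGTH=4096 uv run train_n8n_sft.py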
# ============================================================================
# AUTHENTICATION
# ============================================================================
print("=" * 60)
print("ENTRAINEMENT SFT - N8N EXPERT")
print("=" * 60)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("Authenticated with Hugging Face")
else:
    print("Warning: HF_TOKEN not set, push to the Hub is disabled")
# Disable wandb to avoid dependency conflicts
report_to = "none"
print("Experiment tracking disabled (no wandb)")
# ============================================================================
# MODEL LOADING
# ============================================================================
print(f"\nChargement du modele: {MODEL_NAME}")
# Configuration quantization si necessaire
if USE_4BIT:
print("Mode 4-bit active")
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
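    # NF4 weights with double quantization and bfloat16 compute keep the 14B
    # base model at roughly a quarter of its bf16 footprint on the H100.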
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=True,
)
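    # prepare_model_for_kbit_training freezes the quantized base weights, casts
    # norm layers to fp32 and enables input gradients so LoRA training is stable.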
model = prepare_model_for_kbit_training(model)
else:
    print("bfloat16 mode")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa", # SDPA is built into PyTorch, no extra install needed
device_map="auto",
trust_remote_code=True,
)
# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
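# Some base models ship without a dedicated pad token; fall back to EOS and pad
# on the right, the usual setup for causal-LM supervised fine-tuning.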
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print(f"Modele charge: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
# ============================================================================
# LORA CONFIGURATION
# ============================================================================
print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}")
lora_config = LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
lora_dropout=LORA_DROPOUT,
bias="none",
task_type="CAUSAL_LM"
)
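# Adapting every attention and MLP projection (rather than only q_proj/v_proj)
# is a common LoRA setup that trades a few extra trainable parameters for capacity.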
# ============================================================================
# DATASET LOADING (FIX: read the JSONL directly to avoid schema conflicts)
# ============================================================================
print(f"\nChargement du dataset: {DATASET_REPO}")
def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
"""
Charge un dataset JSONL directement en ne gardant que la colonne 'messages'.
Evite les problemes de schema avec les colonnes struct comme 'nodes_used'.
"""
# Telecharger le fichier
local_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
repo_type="dataset"
)
    # Read the JSONL and keep only the 'messages' field
messages_list = []
with open(local_path, 'r', encoding='utf-8') as f:
for line in f:
data = json.loads(line)
messages_list.append({"messages": data["messages"]})
    # Build a Hugging Face Dataset from the filtered records
return Dataset.from_list(messages_list)
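# Each JSONL line is expected to carry at least a 'messages' list of chat turns,
# roughly as follows (field contents are illustrative; extra columns such as
# 'nodes_used' are dropped on purpose):
#   {"messages": [{"role": "user", "content": "..."},
#                 {"role": "assistant", "content": "..."}], "nodes_used": {...}}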
# Load the train and validation splits
train_dataset = load_jsonl_dataset(DATASET_REPO, TRAIN_FILE)
val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)
print(f"Train: {len(train_dataset)} exemples")
print(f"Validation: {len(val_dataset)} exemples")
# Fonction de formatage
def format_example(example):
"""Formate les messages en texte pour l'entrainement"""
messages = example["messages"]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False
)
return {"text": text}
# Apply the formatting and drop the original columns
print("Formatting data...")
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)
# Show one formatted example
print("\nExample of formatted data:")
print(train_dataset[0]["text"][:500] + "...")
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================
print(f"\nConfiguration d'entrainement:")
print(f" - Epochs: {NUM_EPOCHS}")
print(f" - Batch size: {BATCH_SIZE}")
print(f" - Gradient accumulation: {GRAD_ACCUM}")
print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print(f" - Learning rate: {LEARNING_RATE}")
print(f" - Max sequence length: {MAX_SEQ_LENGTH}")
training_args = SFTConfig(
output_dir=OUTPUT_DIR,
num_train_epochs=NUM_EPOCHS,
per_device_train_batch_size=BATCH_SIZE,
per_device_eval_batch_size=BATCH_SIZE,
gradient_accumulation_steps=GRAD_ACCUM,
learning_rate=LEARNING_RATE,
lr_scheduler_type="cosine",
warmup_ratio=0.1,
weight_decay=0.01,
bf16=True,
tf32=True,
logging_steps=10,
save_strategy="steps",
save_steps=500,
save_total_limit=3,
eval_strategy="steps",
eval_steps=500,
    max_length=MAX_SEQ_LENGTH,  # newer TRL releases use max_length (formerly max_seq_length)
    packing=False,  # disabled: without flash-attention-style masking, packed examples can attend across boundaries
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
dataset_text_field="text",
report_to=report_to,
run_name="n8n-expert-sft",
hub_model_id=HF_REPO if hf_token else None,
push_to_hub=bool(hf_token),
hub_strategy="checkpoint",
)
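# With hub_strategy="checkpoint", the latest checkpoint is also pushed to a
# "last-checkpoint" folder on the Hub, so an interrupted job can be resumed with
# trainer.train(resume_from_checkpoint="last-checkpoint").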
# ============================================================================
# TRAINING
# ============================================================================
print("\nInitialisation du trainer...")
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
peft_config=lora_config,
processing_class=tokenizer, # renamed from tokenizer in TRL 0.12+
)
# Report the trainable parameter count
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
print("\n" + "=" * 60)
print("DEMARRAGE DE L'ENTRAINEMENT")
print("=" * 60)
trainer.train()
# ============================================================================
# SAVING
# ============================================================================
print("\nSauvegarde du modele...")
trainer.save_model(f"{OUTPUT_DIR}/final")
if hf_token:
    print(f"Pushing to {HF_REPO}...")
    trainer.push_to_hub()
    print(f"Model available at: https://huggingface.co/{HF_REPO}")
print("\n" + "=" * 60)
print("ENTRAINEMENT TERMINE")
print("=" * 60)