File size: 9,371 Bytes
807886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d492a0
807886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471ffcc
807886d
 
 
471ffcc
807886d
 
 
 
 
471ffcc
807886d
 
 
 
 
 
 
 
 
 
 
471ffcc
807886d
 
 
 
 
 
 
 
 
 
 
f5c4f4a
 
807886d
 
 
 
 
 
471ffcc
807886d
 
 
 
 
471ffcc
807886d
471ffcc
807886d
471ffcc
6d492a0
471ffcc
807886d
 
471ffcc
807886d
 
471ffcc
807886d
471ffcc
807886d
471ffcc
807886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517aef6
807886d
 
 
 
 
 
 
 
 
 
471ffcc
807886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471ffcc
807886d
 
 
 
471ffcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807886d
80fe464
 
807886d
 
 
471ffcc
807886d
 
 
 
 
 
 
 
 
471ffcc
80fe464
 
807886d
 
471ffcc
80fe464
807886d
 
471ffcc
807886d
 
471ffcc
807886d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d48090
f5c4f4a
807886d
 
 
 
 
 
 
 
 
 
 
471ffcc
807886d
 
 
 
 
 
 
80fe464
 
807886d
e388b5a
807886d
 
471ffcc
807886d
 
471ffcc
807886d
 
471ffcc
807886d
 
 
 
 
 
 
 
471ffcc
807886d
 
 
 
 
471ffcc
807886d
 
471ffcc
807886d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "transformers>=4.45.0",
#     "trl>=0.12.0",
#     "peft>=0.13.0",
#     "datasets>=3.0.0",
#     "accelerate>=1.0.0",
#     "bitsandbytes>=0.44.0",
#     "huggingface_hub>=0.26.0",
#     "torch>=2.4.0",
#     "einops>=0.8.0",
#     "sentencepiece>=0.2.0",
# ]
# [tool.uv]
# index-strategy = "unsafe-best-match"
# extra-index-url = ["https://download.pytorch.org/whl/cu124"]
# ///
"""
Script d'entraînement SFT pour le modèle n8n Expert.

Usage sur HuggingFace Jobs:
    hf jobs uv run \
        --script train_n8n_sft.py \
        --flavor h100x1 \
        --name n8n-expert-sft \
        --timeout 24h

Variables d'environnement requises:
    - HF_TOKEN: Token HuggingFace avec accès en écriture
    - WANDB_API_KEY: (optionnel) Pour le tracking W&B
"""

import os
import json
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login, hf_hub_download

# ============================================================================
# CONFIGURATION
# ============================================================================

# Modele de base
MODEL_NAME = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-14B-Instruct")

# Dataset
DATASET_REPO = "stmasson/n8n-agentic-multitask"
TRAIN_FILE = "data/multitask_large/train.jsonl"
VAL_FILE = "data/multitask_large/val.jsonl"

# Output
OUTPUT_DIR = "./n8n-expert-sft"
HF_REPO = os.environ.get("HF_REPO", "stmasson/n8n-expert-14b-sft")

# Hyperparametres
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "8192"))

# LoRA
LORA_R = int(os.environ.get("LORA_R", "64"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "128"))
LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))

# Quantization (pour economiser la VRAM) - 4-bit par defaut pour H100
USE_4BIT = os.environ.get("USE_4BIT", "true").lower() == "true"

# ============================================================================
# AUTHENTIFICATION
# ============================================================================

print("=" * 60)
print("ENTRAINEMENT SFT - N8N EXPERT")
print("=" * 60)

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("Authentifie sur HuggingFace")
else:
    print("Warning: HF_TOKEN non defini, push desactive")

# Desactive wandb pour eviter les conflits de dependances
report_to = "none"
print("Tracking desactive (pas de wandb)")

# ============================================================================
# CHARGEMENT DU MODELE
# ============================================================================

print(f"\nChargement du modele: {MODEL_NAME}")

# Configuration quantization si necessaire
if USE_4BIT:
    print("Mode 4-bit active")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    model = prepare_model_for_kbit_training(model)
else:
    print("Mode bfloat16")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # SDPA is built into PyTorch, no extra install needed
        device_map="auto",
        trust_remote_code=True,
    )

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Modele charge: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")

# ============================================================================
# CONFIGURATION LORA
# ============================================================================

print(f"\nConfiguration LoRA: r={LORA_R}, alpha={LORA_ALPHA}")

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)

# ============================================================================
# CHARGEMENT DU DATASET (FIX: chargement direct JSON pour eviter les conflits de schema)
# ============================================================================

print(f"\nChargement du dataset: {DATASET_REPO}")

def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
    """
    Charge un dataset JSONL directement en ne gardant que la colonne 'messages'.
    Evite les problemes de schema avec les colonnes struct comme 'nodes_used'.
    """
    # Telecharger le fichier
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset"
    )

    # Lire le JSONL et extraire uniquement 'messages'
    messages_list = []
    with open(local_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            messages_list.append({"messages": data["messages"]})

    # Creer le Dataset
    return Dataset.from_list(messages_list)

# Charger train et validation
train_dataset = load_jsonl_dataset(DATASET_REPO, TRAIN_FILE)
val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)

print(f"Train: {len(train_dataset)} exemples")
print(f"Validation: {len(val_dataset)} exemples")

# Fonction de formatage
def format_example(example):
    """Formate les messages en texte pour l'entrainement"""
    messages = example["messages"]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )
    return {"text": text}

# Appliquer le formatage
print("Formatage des donnees...")
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)

# Afficher un exemple
print("\nExemple de donnees formatees:")
print(train_dataset[0]["text"][:500] + "...")

# ============================================================================
# CONFIGURATION D'ENTRAINEMENT
# ============================================================================

print(f"\nConfiguration d'entrainement:")
print(f"  - Epochs: {NUM_EPOCHS}")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Gradient accumulation: {GRAD_ACCUM}")
print(f"  - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Max sequence length: {MAX_SEQ_LENGTH}")

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    bf16=True,
    tf32=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,
    max_length=MAX_SEQ_LENGTH,  # renamed from max_seq_length in TRL 0.12+
    packing=False,  # Disabled: packing requires flash attention for proper cross-attention masking
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_text_field="text",
    report_to=report_to,
    run_name="n8n-expert-sft",
    hub_model_id=HF_REPO if hf_token else None,
    push_to_hub=bool(hf_token),
    hub_strategy="checkpoint",
)

# ============================================================================
# ENTRAINEMENT
# ============================================================================

print("\nInitialisation du trainer...")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,  # renamed from tokenizer in TRL 0.12+
)

# Afficher les parametres entrainables
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nParametres entrainables: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")

print("\n" + "=" * 60)
print("DEMARRAGE DE L'ENTRAINEMENT")
print("=" * 60)

trainer.train()

# ============================================================================
# SAUVEGARDE
# ============================================================================

print("\nSauvegarde du modele...")
trainer.save_model(f"{OUTPUT_DIR}/final")

if hf_token:
    print(f"Push vers {HF_REPO}...")
    trainer.push_to_hub()
    print(f"Modele disponible sur: https://huggingface.co/{HF_REPO}")

print("\n" + "=" * 60)
print("ENTRAINEMENT TERMINE")
print("=" * 60)