# Source: training-scripts/scripts/train_sft_n8n_multitask.py
# Uploaded by stmasson with huggingface_hub (commit 81400b6, verified).
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "transformers>=4.46.0",
# "accelerate>=0.24.0",
# "peft>=0.7.0",
# "trackio",
# "bitsandbytes",
# "sentencepiece",
# "protobuf",
# ]
# ///
"""
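SFT training for n8n agentic multi-task workflows.

Continues fine-tuning from the ORPO-trained LoRA adapter
stmasson/mistral-7b-n8n-thinking-orpo, which is first merged into its base model
stmasson/mistral-7b-n8n-workflows. A new LoRA adapter is then trained on the
n8n-agentic-multitask dataset for complex multi-step tasks: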
- generate: Create n8n workflows from descriptions
- edit: Modify existing workflows
- fix: Repair broken workflows
- improve: Optimize and enhance workflows
- explain: Describe what workflows do
- debug: Diagnose workflow issues
The model learns to use <thinking> tags for chain-of-thought reasoning
before producing structured JSON outputs.
"""
import torch
import trackio
from datasets import Dataset, load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
# Load the multitask train/val files in streaming mode. The records' "metadata"
# field has a variable schema, so the files are streamed here and converted to
# regular datasets further below.
print("Loading n8n-agentic-multitask dataset...")


def _stream_split(data_file):
    """Open one JSONL file of the n8n-agentic-multitask repo as a streaming split."""
    return load_dataset(
        "stmasson/n8n-agentic-multitask",
        data_files=data_file,
        # A lone data_files entry is exposed under the "train" split name,
        # so "train" is also what selects the validation file.
        split="train",
        streaming=True,
    )


train_stream = _stream_split("data/multitask_large/train.jsonl")
eval_stream = _stream_split("data/multitask_large/val.jsonl")
# Row projection used to keep only the chat transcript for SFT.
def extract_messages(example):
    """Project a raw record down to its chat transcript.

    Args:
        example: one dataset record; must contain a ``"messages"`` key.

    Returns:
        A new dict holding only the record's ``"messages"`` value (the column
        the SFT trainer consumes). All other fields are left out of the result.

    Raises:
        KeyError: if the record has no ``"messages"`` key.
    """
    transcript = example["messages"]
    return dict(messages=transcript)
# Keep only the "messages" column: extract_messages returns it, and the other
# known fields are dropped via remove_columns.
train_messages_stream = train_stream.map(
    extract_messages, remove_columns=["task_type", "metadata"]
)
eval_messages_stream = eval_stream.map(
    extract_messages, remove_columns=["task_type", "metadata"]
)


def _iter_examples(examples):
    """Yield every example of ``examples``; generator source for Dataset.from_generator."""
    yield from examples


# Materialize the streams into regular map-style Datasets, so the trainer gets
# objects with a known length. The source stream is handed over explicitly via
# gen_kwargs. The previous version used a lambda closing over the very global
# name this statement rebinds, which worked only because from_generator
# consumes the generator before the assignment happens.
print("Converting streaming dataset to memory...")
train_dataset = Dataset.from_generator(
    _iter_examples, gen_kwargs={"examples": train_messages_stream}
)
eval_dataset = Dataset.from_generator(
    _iter_examples, gen_kwargs={"examples": eval_messages_stream}
)
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
# Hub repos: the ORPO-trained LoRA adapter, and the base model it gets loaded onto.
MODEL_NAME = "stmasson/mistral-7b-n8n-thinking-orpo"
BASE_MODEL = "stmasson/mistral-7b-n8n-workflows"

# The tokenizer is taken from the ORPO adapter repo. When it defines no pad
# token, EOS is reused for padding.
print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Step 1: load the base model unquantized, in bf16, so the ORPO LoRA adapter
# can be merged into its weights. ("full precision" in the log line means
# "not quantized"; the dtype is bfloat16.)
print(f"Loading base model {BASE_MODEL} (full precision for merge)...")
base_model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "attn_implementation": "sdpa",
}
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **base_model_kwargs)

# Attach the ORPO adapter, then fold it into the base weights and drop the PEFT
# wrapper. `model` ends up as a plain transformers model for the next LoRA run.
print(f"Loading ORPO adapter from {MODEL_NAME}...")
model = PeftModel.from_pretrained(base_model, MODEL_NAME)
print("Merging ORPO adapter into base model...")
model = model.merge_and_unload()
print("ORPO adapter merged successfully!")
# Step 2: prepare the merged model for a fresh LoRA run.
# Gradient checkpointing trades recomputation for activation memory.
# enable_input_require_grads makes the embedding outputs require grad, so
# checkpointed segments still backpropagate while only adapter weights train.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# New LoRA adapter for SFT: rank 32, alpha 64, dropout 0.05, applied to the
# attention and MLP projection modules listed below.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
)
# SFT training configuration, assembled from themed groups of arguments.
# Each key appears in exactly one group; SFTConfig receives the union.
_hub_args = dict(
    output_dir="mistral-7b-n8n-agentic-multitask",
    push_to_hub=True,
    hub_model_id="stmasson/mistral-7b-n8n-agentic-multitask",
    hub_strategy="every_save",  # push to the Hub at every checkpoint save
    hub_private_repo=False,
)
_optimization_args = dict(
    num_train_epochs=1,  # large dataset: a single pass
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,  # effective batch = 1 x 32 per process
    learning_rate=2e-5,  # low LR: continuing from an already fine-tuned model
    max_length=4096,  # long context for complex workflows
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
)
_memory_args = dict(
    gradient_checkpointing=True,
    bf16=True,
)
_schedule_args = dict(
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,  # same cadence as save_steps
)
_tracking_args = dict(
    report_to="trackio",
    project="n8n-agentic-training",
    run_name="mistral-7b-multitask-sft",
)
config = SFTConfig(
    **_hub_args,
    **_optimization_args,
    **_memory_args,
    **_schedule_args,
    **_tracking_args,
)
# Build the SFT trainer. Because peft_config is given, the trainer wraps the
# merged model with the new LoRA adapter instead of training all weights.
print("Initializing SFT trainer...")
trainer = SFTTrainer(
    model=model,
    args=config,
    processing_class=tokenizer,
    peft_config=lora_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
# Run training, push the final adapter, and close the Trackio run.
#
# NOTE(review): `model` is a plain transformers model produced by merging the
# ORPO adapter into BASE_MODEL inside this process. The adapter pushed below
# therefore presumably records BASE_MODEL (not the ORPO-merged weights) as its
# base model. Loading it onto BASE_MODEL alone would silently skip the ORPO
# update. Confirm adapter_config.json and consider also publishing the merged
# ORPO model (or the fully merged result).
print("Starting SFT training...")
print(f" Base: {MODEL_NAME} (merged)")
print(" Dataset: stmasson/n8n-agentic-multitask")
print(f" Output: {config.hub_model_id}")
print(" Tasks: generate, edit, fix, improve, explain, debug")
try:
    trainer.train()
    print("Pushing final model to Hub...")
    trainer.push_to_hub()
finally:
    # Finish Trackio even when training or the final push raises, so the run
    # is closed; the original exception still propagates.
    trackio.finish()
print("Training complete!")
print(f"Model: https://huggingface.co/{config.hub_model_id}")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")