File size: 5,343 Bytes
334164b 81400b6 334164b 81400b6 334164b 81400b6 334164b 81400b6 334164b 81400b6 334164b 81400b6 334164b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
#!/usr/bin/env python3
# /// script
# dependencies = [
# "trl>=0.12.0",
# "transformers>=4.46.0",
# "accelerate>=0.24.0",
# "peft>=0.7.0",
# "trackio",
# "bitsandbytes",
# "sentencepiece",
# "protobuf",
# ]
# ///
"""
SFT training for n8n agentic multi-task workflows.
Continues fine-tuning from stmasson/mistral-7b-n8n-thinking-orpo (ORPO-trained model)
on the n8n-agentic-multitask dataset for complex multi-step tasks:
- generate: Create n8n workflows from descriptions
- edit: Modify existing workflows
- fix: Repair broken workflows
- improve: Optimize and enhance workflows
- explain: Describe what workflows do
- debug: Diagnose workflow issues
The model learns to use <thinking> tags for chain-of-thought reasoning
before producing structured JSON outputs.
"""
import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
# ---------------------------------------------------------------------------
# Dataset loading.
# The `metadata` column has a variable schema across rows, which breaks
# Arrow schema inference on a regular (non-streaming) load. Work around it
# by streaming, dropping the offending columns, then materializing the
# clean rows into an in-memory Dataset.
# ---------------------------------------------------------------------------
from datasets import Dataset


def _load_split(jsonl_path):
    """Stream one JSONL split from the Hub and materialize it in memory.

    Args:
        jsonl_path: Path of the split file inside the dataset repo.

    Returns:
        An in-memory ``datasets.Dataset`` containing only the 'messages'
        column (the chat format SFTTrainer expects).
    """
    stream = load_dataset(
        "stmasson/n8n-agentic-multitask",
        data_files=jsonl_path,
        split="train",  # with data_files=, each file is exposed as 'train'
        streaming=True,  # streaming skips eager Arrow schema inference
    )
    # Keep only 'messages'; task_type/metadata have inconsistent schemas.
    stream = stream.map(
        lambda example: {"messages": example["messages"]},
        remove_columns=["task_type", "metadata"],
    )
    # Materializing restores len(), random access, and multi-epoch support
    # that IterableDataset lacks.
    return Dataset.from_generator(lambda: (x for x in stream))


print("Loading n8n-agentic-multitask dataset...")
train_dataset = _load_split("data/multitask_large/train.jsonl")
eval_dataset = _load_split("data/multitask_large/val.jsonl")
print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
# ---------------------------------------------------------------------------
# Tokenizer and model setup.
# The ORPO run produced a LoRA adapter on top of the workflow base model;
# merge it into full-precision weights before stacking a new SFT adapter.
# ---------------------------------------------------------------------------
MODEL_NAME = "stmasson/mistral-7b-n8n-thinking-orpo"
BASE_MODEL = "stmasson/mistral-7b-n8n-workflows"

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Mistral ships without a pad token; reuse EOS so batching works.
    tokenizer.pad_token = tokenizer.eos_token

# Step 1: load the base model in bf16 (no quantization) so the ORPO
# adapter can be merged into clean full-precision weights.
print(f"Loading base model {BASE_MODEL} (full precision for merge)...")
orpo_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

print(f"Loading ORPO adapter from {MODEL_NAME}...")
model = PeftModel.from_pretrained(orpo_base, MODEL_NAME)
print("Merging ORPO adapter into base model...")
model = model.merge_and_unload()
print("ORPO adapter merged successfully!")

# Step 2: gradient checkpointing trades compute for memory; input grads
# must be enabled so checkpointed activations backprop through the LoRA.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
# Fresh LoRA adapter to train on top of the merged ORPO weights.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,             # adapter rank
    lora_alpha=64,    # scaling factor (alpha / r = 2.0)
    lora_dropout=0.05,
    # Attach to every attention and MLP projection in the Mistral blocks.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)
# SFT training configuration.
config = SFTConfig(
    # --- Hub publishing ---
    output_dir="mistral-7b-n8n-agentic-multitask",
    push_to_hub=True,
    hub_model_id="stmasson/mistral-7b-n8n-agentic-multitask",
    hub_strategy="every_save",
    hub_private_repo=False,
    # --- Schedule: one pass over the large dataset ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,  # effective batch size = 1 * 32
    learning_rate=2e-5,  # conservative LR for continued fine-tuning
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",  # 8-bit optimizer states reduce memory
    # --- Sequence length & memory ---
    max_length=4096,  # long context for complex multi-node workflows
    gradient_checkpointing=True,
    bf16=True,
    # --- Logging, checkpointing & evaluation ---
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,
    # --- Experiment tracking (Trackio) ---
    report_to="trackio",
    project="n8n-agentic-training",
    run_name="mistral-7b-multitask-sft",
)
# ---------------------------------------------------------------------------
# Training.
# SFTTrainer applies lora_config to the merged model, trains the adapter,
# and pushes checkpoints to the Hub per the hub_strategy in `config`.
# ---------------------------------------------------------------------------
print("Initializing SFT trainer...")
trainer = SFTTrainer(
    model=model,  # merged ORPO model; new LoRA is attached via peft_config
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=config,
)

# NOTE: these were f-strings with no placeholders; plain strings suffice.
print("Starting SFT training...")
print(" Base: stmasson/mistral-7b-n8n-thinking-orpo (merged)")
print(" Dataset: stmasson/n8n-agentic-multitask")
print(" Output: stmasson/mistral-7b-n8n-agentic-multitask")
print(" Tasks: generate, edit, fix, improve, explain, debug")
trainer.train()

print("Pushing final model to Hub...")
trainer.push_to_hub()

# Close the Trackio run so metrics are flushed to the tracking Space.
trackio.finish()

print("Training complete!")
print("Model: https://huggingface.co/stmasson/mistral-7b-n8n-agentic-multitask")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")
|