#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "transformers>=4.46.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "trackio",
#     "bitsandbytes",
#     "sentencepiece",
#     "protobuf",
# ]
# ///

"""
SFT training for n8n agentic multi-task workflows.

Continues fine-tuning from stmasson/mistral-7b-n8n-thinking-orpo (ORPO-trained model)
on the n8n-agentic-multitask dataset for complex multi-step tasks:
- generate: Create n8n workflows from descriptions
- edit: Modify existing workflows
- fix: Repair broken workflows
- improve: Optimize and enhance workflows
- explain: Describe what workflows do
- debug: Diagnose workflow issues

The model learns to use <thinking> tags for chain-of-thought reasoning
before producing structured JSON outputs.
"""

import trackio
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig


# Load multitask dataset - use streaming to avoid schema issues, then convert
print("Loading n8n-agentic-multitask dataset...")

# Stream both JSONL files so the variable-schema `metadata` column cannot
# break eager Arrow schema inference at load time. Each file is addressed
# directly via `data_files`, so it loads as its own "train" split.
def _stream_jsonl(relative_path):
    """Open one dataset file from the Hub repo as a streaming split."""
    return load_dataset(
        "stmasson/n8n-agentic-multitask",
        data_files=relative_path,
        split="train",
        streaming=True,
    )

train_stream = _stream_jsonl("data/multitask_large/train.jsonl")
eval_stream = _stream_jsonl("data/multitask_large/val.jsonl")

# Only keep the 'messages' column (required for SFT)
def extract_messages(example):
    """Project a dataset row down to its chat `messages` field only."""
    # Raises KeyError if `messages` is absent, matching the direct lookup.
    return {key: example[key] for key in ("messages",)}

# Keep only the 'messages' column; drop task_type/metadata, whose variable
# schema is the reason the dataset was streamed in the first place.
train_dataset = train_stream.map(extract_messages, remove_columns=["task_type", "metadata"])
eval_dataset = eval_stream.map(extract_messages, remove_columns=["task_type", "metadata"])

# Convert streaming to regular dataset (materializes in memory)
from datasets import Dataset

print("Converting streaming dataset to memory...")

def _materialize(stream):
    """Drain a streaming IterableDataset into an in-memory Dataset.

    Fix: the original `Dataset.from_generator(lambda: (x for x in train_dataset))`
    closed over the module-level name being reassigned on that very line, so
    correctness depended on `from_generator` consuming the generator before the
    assignment completed. Closing over the local `stream` (never rebound) makes
    the generator safe to call at any time, e.g. on cache regeneration.
    """
    return Dataset.from_generator(lambda: iter(stream))

train_dataset = _materialize(train_dataset)
eval_dataset = _materialize(eval_dataset)

print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")

# Load tokenizer from ORPO-trained model
# MODEL_NAME holds the ORPO-stage adapter repo; BASE_MODEL is the model it
# was trained on top of. Both are referenced again when merging below.
MODEL_NAME = "stmasson/mistral-7b-n8n-thinking-orpo"
BASE_MODEL = "stmasson/mistral-7b-n8n-workflows"

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # No pad token defined on this tokenizer — alias it to EOS so padding
    # during batching works.
    tokenizer.pad_token = tokenizer.eos_token

# Step 1: Load base model WITHOUT quantization to merge ORPO adapter
print(f"Loading base model {BASE_MODEL} (full precision for merge)...")
full_precision_base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

print(f"Loading ORPO adapter from {MODEL_NAME}...")
model = PeftModel.from_pretrained(full_precision_base, MODEL_NAME)

print("Merging ORPO adapter into base model...")
# Fold the LoRA deltas into the base weights; the result is a plain
# transformers model onto which a fresh adapter can be trained.
model = model.merge_and_unload()
print("ORPO adapter merged successfully!")

# Step 2: Prepare for LoRA training with gradient checkpointing
model.gradient_checkpointing_enable()
# Keeps checkpointed inputs grad-enabled when only adapter weights train.
model.enable_input_require_grads()

# New LoRA configuration for SFT training — adapters on every attention and
# MLP projection of the Mistral architecture.
_TARGET_PROJECTIONS = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]
lora_config = LoraConfig(
    r=32,              # adapter rank
    lora_alpha=64,     # scaling = alpha / r = 2.0
    lora_dropout=0.05,
    target_modules=_TARGET_PROJECTIONS,
    task_type="CAUSAL_LM",
)

# SFT training configuration
config = SFTConfig(
    # Training parameters
    num_train_epochs=1,  # Large dataset, 1 epoch is enough
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,  # effective batch size = 1 * 32 = 32
    learning_rate=2e-5,  # Lower LR for continued fine-tuning
    max_length=4096,  # Longer context for complex workflows

    # Memory optimization
    gradient_checkpointing=True,
    bf16=True,

    # Optimizer & schedule
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",

    # Logging & checkpointing
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,

    # Evaluation
    eval_strategy="steps",
    eval_steps=500,

    # Hub settings
    output_dir="mistral-7b-n8n-agentic-multitask",
    push_to_hub=True,
    hub_model_id="stmasson/mistral-7b-n8n-agentic-multitask",
    hub_strategy="every_save",
    hub_private_repo=False,

    # Monitoring
    report_to="trackio",
    project="n8n-agentic-training",
    run_name="mistral-7b-multitask-sft",
)

# Initialize trainer
print("Initializing SFT trainer...")
trainer = SFTTrainer(
    model=model,                  # merged ORPO model; fresh LoRA applied via peft_config
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=config,
)

print("Starting SFT training...")
# Fix: these four literals carried an `f` prefix with no placeholders
# (ruff F541); plain strings produce byte-identical output.
print("  Base: stmasson/mistral-7b-n8n-thinking-orpo (merged)")
print("  Dataset: stmasson/n8n-agentic-multitask")
print("  Output: stmasson/mistral-7b-n8n-agentic-multitask")
print("  Tasks: generate, edit, fix, improve, explain, debug")

trainer.train()

print("Pushing final model to Hub...")
trainer.push_to_hub()

# Finish Trackio
trackio.finish()

print("Training complete!")
print("Model: https://huggingface.co/stmasson/mistral-7b-n8n-agentic-multitask")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")