#!/usr/bin/env python3
"""
train_lora.py
- Fine-tune DeepSeek-Coder 1.3B with LoRA adapters (QLoRA-style recipe, but the base model is loaded in fp16 rather than 4-bit)
- Save adapters using safe_serialization=True -> adapter_model.safetensors
- Upload adapter folder to Hugging Face Hub (VaibhavHD/deepseek-lora-monthly)
- Log metrics/artifact to Weights & Biases
"""
import os
import json
import wandb
import torch
from huggingface_hub import HfApi
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
# -----------------------------
# Config (edit if needed)
# -----------------------------
HF_REPO = "VaibhavHD/deepseek-lora-monthly" # your HF model repo
MODEL_NAME = "deepseek-ai/deepseek-coder-1.3b-base"
OUT_DIR = "out"
ADAPTER_DIR = os.path.join(OUT_DIR, "lora_adapters")
# env secrets expected:
HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
if WANDB_API_KEY:
    wandb.login(key=WANDB_API_KEY)
else:
    print("⚠️ WANDB_API_KEY not found in env; continuing without W&B logging.")
# -----------------------------
# Load dataset
# -----------------------------
print("Loading dataset...")
dataset = {}
dataset['train'] = load_dataset("westenfelder/NL2SH-ALFA", "train")["train"]
dataset['test'] = load_dataset("westenfelder/NL2SH-ALFA", "test")["train"]
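# Each example is expected to expose an "nl" (natural-language request) and a "bash"
# (shell command) field, as consumed by tokenize_fn below.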
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
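# Padding below assumes a pad token is defined; some base checkpoints ship without one,
# so fall back to the EOS token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token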
def tokenize_fn(batch):
    texts = [f"{nl} => {bash}" for nl, bash in zip(batch["nl"], batch["bash"])]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=512)
train = dataset["train"].map(tokenize_fn, batched=True)
test = dataset["test"].map(tokenize_fn, batched=True)
# Optional small-subset for fast runs (uncomment to use)
# train = train.shuffle(seed=42).select(range(200))
# test = test.shuffle(seed=42).select(range(20))
# -----------------------------
# Load base model (half precision)
# -----------------------------
print("Loading base model (may take a moment)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True
)
# The KV cache is only useful for generation; disable it during training
model.config.use_cache = False
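# Freeze every base-model parameter; only the LoRA adapter weights added below are trained.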
for p in model.parameters():
    p.requires_grad = False
# -----------------------------
# Attach LoRA
# -----------------------------
print("Attaching LoRA adapters...")
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj", "v_proj", "k_proj", "o_proj",
        "gate_proj", "down_proj", "up_proj"
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
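# Optional: report how many parameters are actually trainable with the adapters attached.
model.print_trainable_parameters()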
# -----------------------------
# Data collator + training args
# -----------------------------
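# With mlm=False the collator does causal-LM batching: labels are a copy of input_ids
# with padded positions masked out (-100).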
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
    output_dir=OUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    fp16=True,
    save_strategy="epoch",
    logging_steps=25,
    report_to=["wandb"] if WANDB_API_KEY else [],
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    data_collator=data_collator,
)
# -----------------------------
# Run training
# -----------------------------
print("Starting training...")
if WANDB_API_KEY:
    wandb.init(project="deepseek-qlora-monthly", name="deepseek-lite-run")
trainer.train()
# -----------------------------
# Evaluate and save metrics
# -----------------------------
print("Evaluating...")
metrics = trainer.evaluate()
# Crude accuracy-like proxy derived from the eval loss (can go negative; replace with a real task metric if you have one)
new_acc = 1.0 - metrics.get("eval_loss", 1.0)
print(f"Eval metrics: {metrics}")
print(f"Pseudo-accuracy (1 - eval_loss): {new_acc:.6f}")
os.makedirs(ADAPTER_DIR, exist_ok=True)
metrics_path = os.path.join(OUT_DIR, "metrics.json")
with open(metrics_path, "w") as f:
    json.dump(metrics, f)
if WANDB_API_KEY:
    wandb.log({"accuracy": new_acc})
    # log artifact
    artifact = wandb.Artifact(
        name="deepseek-lora-adapters",
        type="model",
        description="LoRA adapters saved with safe_serialization"
    )
# -----------------------------
# Save adapters using safe_serialization
# -----------------------------
print("Saving adapters with safe_serialization=True (produces .safetensors)...")
model.save_pretrained(ADAPTER_DIR, safe_serialization=True)
tokenizer.save_pretrained(ADAPTER_DIR)
# add to wandb artifact directory
if WANDB_API_KEY:
    artifact.add_dir(ADAPTER_DIR)
    wandb.log_artifact(artifact, aliases=["latest"])
print(f"Adapters saved to: {ADAPTER_DIR}")
print("Files in adapter dir:", os.listdir(ADAPTER_DIR))
# -----------------------------
# Upload to Hugging Face model repo
# -----------------------------
if HF_TOKEN:
    print(f"Uploading adapter folder to Hugging Face repo: {HF_REPO}")
    api = HfApi()
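    # Make sure the target repo exists before uploading, in case this is the first run
    # (no-op if it already exists).
    api.create_repo(repo_id=HF_REPO, token=HF_TOKEN, exist_ok=True)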
    # upload_folder overwrites files of the same name already in the repo
    api.upload_folder(
        folder_path=ADAPTER_DIR,
        path_in_repo=".",
        repo_id=HF_REPO,
        token=HF_TOKEN
    )
    print("✅ Upload complete.")
else:
    print("⚠️ HF_TOKEN not set. Skipping upload to Hugging Face Hub.")