# Deepseeklora / train_lora.py
import os, json, torch, wandb
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments, DataCollatorForLanguageModeling)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from huggingface_hub import HfApi

# Read credentials from the environment and authenticate with Weights & Biases.
HF_TOKEN = os.getenv("HF_TOKEN")
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
wandb.login(key=WANDB_API_KEY)
# Base model and the NL -> bash instruction dataset.
model_name = "deepseek-ai/deepseek-coder-1.3b-base"
dataset = load_dataset("westenfelder/NL2SH-ALFA")
tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tok.pad_token is None:  # causal-LM tokenizers often lack a pad token; reuse EOS for padding
    tok.pad_token = tok.eos_token
def tok_fn(b):
    # Format each pair as "<natural language> => <bash command>" and tokenize.
    return tok([f"{n} => {bsh}" for n, bsh in zip(b["nl"], b["bash"])],
               truncation=True, padding="max_length", max_length=512)

train, test = dataset["train"].map(tok_fn, batched=True), dataset["test"].map(tok_fn, batched=True)
# Load the base model in fp16 and freeze all of its weights; only the LoRA adapters will train.
m = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16,
                                         low_cpu_mem_usage=True, device_map="auto",
                                         trust_remote_code=True)
m.config.use_cache = False  # the generation KV cache is not needed during training
for p in m.parameters():
    p.requires_grad = False
# Rank-8 LoRA adapters on the attention and MLP projection layers.
cfg = LoraConfig(r=8, lora_alpha=16,
                 target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                                 "gate_proj", "down_proj", "up_proj"],
                 lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
m = get_peft_model(m, cfg)
coll = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)  # causal LM: labels are the shifted inputs
args = TrainingArguments(output_dir="./out", num_train_epochs=1, per_device_train_batch_size=1,
                         gradient_accumulation_steps=8, learning_rate=2e-4, fp16=True,
                         save_strategy="epoch", logging_steps=25, report_to=["wandb"])
t = Trainer(model=m, args=args, train_dataset=train, eval_dataset=test, data_collator=coll)

wandb.init(project="deepseek-qlora-monthly", name="deepseek-lite-run")
t.train()
metrics = t.evaluate()
acc = 1 - metrics.get("eval_loss", 1)  # rough "1 - loss" proxy, not a true accuracy metric
with open("out/metrics.json", "w") as f:
    json.dump(metrics, f)
wandb.log({"accuracy": acc})
print(f"✅ Eval accuracy (1 - eval_loss): {acc:.4f}")
# Save LoRA adapters and tokenizer, log them as a W&B artifact, and push them to the Hub.
ad = "out/lora_adapters"
os.makedirs(ad, exist_ok=True)
m.save_pretrained(ad)
tok.save_pretrained(ad)
artifact = wandb.Artifact("deepseek-lora-adapters", type="model")
artifact.add_dir(ad)
wandb.log_artifact(artifact)
api = HfApi(token=HF_TOKEN)
api.upload_folder(folder_path=ad, repo_id="your-username/deepseek-lora-monthly", path_in_repo=".")
print("✅ Uploaded to HF Hub")
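# Usage sketch (assumption, not part of the original training run): reload the saved
# adapters on top of the same base model for inference with peft.PeftModel.
# from peft import PeftModel
# base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16,
#                                             device_map="auto", trust_remote_code=True)
# lm = PeftModel.from_pretrained(base, "out/lora_adapters")
# prompt = "list all files modified in the last 24 hours =>"   # matches the "nl => bash" training format
# inputs = tok(prompt, return_tensors="pt").to(lm.device)
# out_ids = lm.generate(**inputs, max_new_tokens=64)
# print(tok.decode(out_ids[0], skip_special_tokens=True))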