Deepseeklora / inference.py
VaibhavHD's picture
Update inference.py
5352ede verified
raw
history blame contribute delete
805 Bytes
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import transformers.training_args
# ✅ Fix: allow this class for safe loading in PyTorch 2.6+
# PyTorch >= 2.6 defaults torch.load to weights_only=True; LoRA checkpoints that
# embed a TrainingArguments object would otherwise fail to deserialize.
torch.serialization.add_safe_globals([transformers.training_args.TrainingArguments])
# Base model on the Hub and the LoRA adapter repo layered on top of it.
BASE_MODEL = "deepseek-ai/deepseek-coder-1.3b-base"
LORA_REPO = "VaibhavHD/deepseek-lora-monthly"
# trust_remote_code=True lets the Hub repo's custom tokenizer/model code run —
# required by some DeepSeek releases. Downloads weights on first call.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Wrap the frozen base model with the fine-tuned LoRA adapter weights.
model = PeftModel.from_pretrained(base, LORA_REPO)
def generate_response(prompt: str, max_new_tokens: int = 200) -> str:
    """Generate a completion for *prompt* using the LoRA-adapted model.

    Args:
        prompt: Raw text to feed the tokenizer (no chat template applied).
        max_new_tokens: Upper bound on generated tokens (default 200,
            preserving the original hard-coded behavior).

    Returns:
        The full decoded sequence — note this includes the prompt itself,
        since the output ids are decoded from position 0.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    # Move tensors to wherever the model lives; the original assumed CPU and
    # would crash if the model were moved to a GPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(out[0], skip_special_tokens=True)