import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

app = FastAPI()

# === MODEL ===
MODEL_REPO = "sahil239/falcon-lora-chatbot"  # replace with your HF repo
BASE_MODEL = "tiiuae/falcon-rw-1b"

# === Load tokenizer ===
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # Falcon has no pad token; reuse EOS to avoid padding errors

# === Load base model and attach the LoRA adapter ===
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, MODEL_REPO)
model.eval()

# === Move to GPU if available ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# === Request Schema ===
class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.95


@app.get("/")
def health_check():
    return {"status": "running"}


@app.post("/generate")
async def generate_text(req: PromptRequest):
    inputs = tokenizer(
        req.prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=200,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,  # 🚨 stop once the model emits EOS
            repetition_penalty=1.2,               # 🚫 penalize repeated phrases
            no_repeat_ngram_size=3,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Strip the echoed prompt so only the newly generated text is returned
    return {"response": generated_text[len(req.prompt):].strip()}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
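
# --- Example request (illustrative sketch) ---
# Once the server is running on the host/port configured in uvicorn.run() above,
# the /generate endpoint can be exercised with any HTTP client. The JSON fields
# mirror the PromptRequest schema; the prompt value below is just a placeholder.
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hello, how are you?", "max_new_tokens": 50}'
#
# The endpoint returns a JSON object with a single "response" key containing the
# generated continuation (with the original prompt stripped off).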