from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import uvicorn

app = FastAPI()

# Hugging Face Hub IDs for the LoRA adapter and the base model it was trained on.
MODEL_REPO = "sahil239/falcon-lora-chatbot"
BASE_MODEL = "tiiuae/falcon-rw-1b"

# Load the tokenizer; Falcon defines no pad token, so reuse the EOS token for padding.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Load the base model, attach the LoRA adapter weights, and switch to inference mode.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, MODEL_REPO)
model.eval()

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 200
    temperature: float = 0.7
    top_p: float = 0.95


@app.get("/")
def health_check():
    return {"status": "running"}
@app.post("/generate") |
|
|
async def generate_text(req: PromptRequest): |
|
|
inputs = tokenizer( |
|
|
req.prompt, |
|
|
return_tensors="pt", |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=200 |
|
|
) |
|
|
inputs = {k: v.to(device) for k, v in inputs.items()} |
|
|
|
|
|

    # Sample a completion; the repetition penalty and n-gram block curb looping output.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            top_p=req.top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

    # Decode the full sequence and strip the echoed prompt from the reply.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": generated_text[len(req.prompt):].strip()}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
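

# A quick way to try the /generate endpoint once the server is running
# (assumes the default host/port configured above; the prompt is just an example):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is LoRA fine-tuning?", "max_new_tokens": 100}'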