# models.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Cache of loaded models, keyed by model name
model_cache = {}


def load_model(model_name):
    """Load a Hugging Face model and tokenizer, caching them for reuse."""
    if model_name not in model_cache:
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,  # Half precision to cut memory use
                device_map="auto"           # Place weights on GPU if available
            )
            model_cache[model_name] = {
                'name': model_name,
                'tokenizer': tokenizer,
                'model': model
            }
        except Exception as e:
            raise ValueError(f"Failed to load model {model_name}: {str(e)}")
    return model_cache[model_name]
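

# Hypothetical usage sketch (illustration only; "gpt2" is an assumed small
# public checkpoint, not something this app mandates). A repeated call with
# the same name returns the cached entry instead of reloading the weights:
#
#     data = load_model("gpt2")
#     assert load_model("gpt2") is data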


def chat_with_model(model_data, conversation, streaming=False):
    """Generate a response from the loaded model.

    Note: `streaming` is accepted for interface compatibility but is not
    implemented here; the full response is always returned at once.
    """
    try:
        tokenizer = model_data['tokenizer']
        model = model_data['model']
        inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # Generate up to 100 new tokens
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens, skipping the prompt
        response = tokenizer.decode(
            outputs[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        )
        return response.strip()
    except Exception as e:
        return f"Error generating response: {str(e)}"