# base-d20-demo / app.py
import gradio as gr
import spaces
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model():
    model_dir = "nanochat-students/base-d20"
    # Load the config and model via the Transformers Auto classes; the repo
    # ships custom modeling code, hence trust_remote_code=True
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
        # low_cpu_mem_usage=False avoids meta-device issues when moving to GPU
        low_cpu_mem_usage=False,
    )
    model = model.to(device)
    model.eval()
    # The tokenizer is also custom (tokenizer_nanogpt), loaded via trust_remote_code
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, config=config)
    return tokenizer, model


tokenizer, model = load_model()

@spaces.GPU
def generate(prompt):
    # Encode with the custom tokenizer, prepending the BOS token
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    max_new_tokens = 50
    with torch.inference_mode():
        # Greedy decoding; the full sequence is re-run through the model on
        # every step (no KV cache), which is simple but O(n^2) in length
        for _ in range(max_new_tokens):
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Only the logits at the last position are needed
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
            # Optional: stop early once an EOS token is produced, e.g.
            # if next_token.item() == eos_token_id:
            #     break
    decoded = tokenizer.decode(ids[0].tolist())
    return decoded
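
# Optional variant (not wired into the UI): temperature + top-k sampling
# instead of greedy argmax, which tends to give more varied completions.
# A minimal sketch; the function name and the default temperature/top_k
# values are illustrative choices, not part of the original app.
@spaces.GPU
def generate_sampled(prompt, temperature=0.8, top_k=50, max_new_tokens=50):
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Scale the last-position logits by temperature, keep only the
            # top_k candidates, and sample from the renormalized distribution
            next_token_logits = logits[:, -1, :] / temperature
            topk_logits, topk_indices = torch.topk(next_token_logits, top_k, dim=-1)
            probs = torch.softmax(topk_logits, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            next_token = topk_indices.gather(-1, sampled)
            ids = torch.cat([ids, next_token], dim=1)
    return tokenizer.decode(ids[0].tolist())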

gr.Interface(
    fn=generate,
    inputs=gr.Text(label="Input", lines=15),
    outputs=gr.Text(label="Output", lines=15),
).launch()
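
# Note: launch() with no arguments is what Hugging Face Spaces expects; for
# local testing, standard Gradio options such as share=True or
# server_name="0.0.0.0" can be passed (not used by the original app).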