Spaces: Running on Zero
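The app.py below is a minimal ZeroGPU Space: it loads the nanochat-students/base-d20 checkpoint through the Transformers Auto classes (with trust_remote_code for the custom nanochat model and tokenizer), then serves a greedy-decoding loop behind a Gradio interface.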
```python
import gradio as gr
import spaces
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model():
    model_dir = "nanochat-students/base-d20"
    # Load model via Transformers Auto classes
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # Load the model and move it to the device.
    # low_cpu_mem_usage=False avoids meta-device issues.
    model = AutoModel.from_pretrained(
        model_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=False,
    )
    model = model.to(device)
    model.eval()
    # Load tokenizer via AutoTokenizer (trust_remote_code uses tokenizer_nanogpt)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, config=config)
    return tokenizer, model


tokenizer, model = load_model()


@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each call
def generate(prompt):
    input_ids = tokenizer.encode(prompt, prepend=tokenizer.get_bos_token_id())
    ids = torch.tensor([input_ids], dtype=torch.long, device=device)
    max_new_tokens = 50
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            outputs = model(input_ids=ids)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs.logits
            # Only take the logits for the last token
            next_token_logits = logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=1)
            # Optional: stop early on the EOS token
            # if next_token.item() == eos_token_id:
            #     break
    decoded = tokenizer.decode(ids[0].tolist())
    return decoded


gr.Interface(
    fn=generate,
    inputs=gr.Text(label="Input", lines=15),
    outputs=gr.Text(label="Output", lines=15),
).launch()
```
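The loop above decodes greedily with argmax, so the same prompt always yields the same completion. As a minimal sketch, the argmax step could be swapped for temperature sampling; the temperature value here is an assumption for illustration, not part of the original app:

```python
# Hypothetical replacement for the argmax line inside the decoding loop:
# sample the next token from the softmax distribution instead.
temperature = 0.8  # assumed value, chosen for illustration
probs = torch.softmax(next_token_logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1), same as the argmax version
```

Note that `@spaces.GPU` also accepts a duration argument (e.g. `@spaces.GPU(duration=120)`) if a single generation call needs more than the default GPU window.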