# smollm2-135M / app.py

import os

import torch
import gradio as gr
from transformers import AutoTokenizer

from model import SmolLMForCausalLM, SmolLMConfig

# 1. Configuration constants
MODEL_CHECKPOINT = "model.pt"  # Expects the model weights in this file
TOKENIZER_ID = "HuggingFaceTB/SmolLM-135M"  # The standard SmolLM tokenizer
DEVICE = "cpu"  # HF Spaces free tier is CPU-only; change to "cuda" if a GPU is available.

# 2. Load Model and Tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
print("Initializing model...")
config = SmolLMConfig()
model = SmolLMForCausalLM(config)
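# Note: SmolLMConfig() uses the defaults defined in model.py, which are
# assumed to match the 135M-parameter architecture used during training.
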
# 3. Load Weights
if os.path.exists(MODEL_CHECKPOINT):
    print(f"Loading weights from {MODEL_CHECKPOINT}...")
    try:
        # Map to CPU so loading works even on machines without a GPU
        checkpoint = torch.load(MODEL_CHECKPOINT, map_location=torch.device("cpu"))
        # The file may be a full training checkpoint (a dict containing
        # 'model_state_dict') or a bare state dict of weights
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        # Strip the "_orig_mod." prefix that torch.compile adds to parameter
        # names when a compiled model is saved
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("_orig_mod."):
                new_state_dict[k[len("_orig_mod."):]] = v
            else:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
        print("Running with randomly initialized weights for demonstration.")
else:
    print(f"Warning: {MODEL_CHECKPOINT} not found! Running with random weights.")
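
# For reference, a compatible checkpoint can be produced on the training side
# with a plain torch.save call (hypothetical snippet; adapt to the actual
# training script):
#
#   torch.save({"model_state_dict": model.state_dict()}, MODEL_CHECKPOINT)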

model.to(DEVICE)
model.eval()  # disable dropout etc. for inference

# 4. Generation Function
def generate_text(prompt, max_new_tokens, temperature, top_k):
    if not prompt:
        return "Please enter a prompt."

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Text generation loop: a simple autoregressive loop, similar to the
    # training script's generate function, with temperature and top-k
    # sampling added for more variety in the demo.
    curr_input_ids = input_ids
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            # Forward pass; keep only the logits for the last position
            logits = model(curr_input_ids)
            next_token_logits = logits[:, -1, :]

            # Apply temperature scaling
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            else:
                # Greedy decoding when temperature is 0: take the argmax
                # directly instead of sampling
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
                curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
                continue
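
            # Temperature intuition (worked example): dividing logits
            # [2.0, 1.0] by T=0.5 gives [4.0, 2.0], sharpening the softmax
            # toward the top token, while T=2.0 gives [1.0, 0.5], flattening
            # the distribution toward uniform.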

            # Apply top-k filtering: mask everything below the k-th largest
            # logit so only the k most likely tokens can be sampled
            if top_k > 0:
                v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                next_token_logits[next_token_logits < v[:, [-1]]] = float("-inf")

            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            # Sample the next token id from the filtered distribution
            next_token_id = torch.multinomial(probs, num_samples=1)
            curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
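
            # Top-k intuition (worked example): with top_k=2 and logits
            # [3.0, 1.0, 2.5, 0.1], the 2nd-largest logit is 2.5, so 1.0 and
            # 0.1 are set to -inf and get probability 0 after softmax;
            # sampling then chooses only between the two strongest tokens.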

            # Optional: stop early if an EOS token is generated (only useful
            # if one was defined and used during training)
            # if next_token_id.item() == tokenizer.eos_token_id:
            #     break

    output_text = tokenizer.decode(curr_input_ids[0].tolist(), skip_special_tokens=True)
    return output_text
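
# Quick local sanity check (hypothetical prompt; left commented out so it
# does not run on Space startup):
# print(generate_text("Once upon a time", 20, 0.8, 40))
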
# 5. Build Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# SmolLM-135M Implementation Demo")
    gr.Markdown("This is a demo of a 135M-parameter transformer model trained from scratch.")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Once upon a time...", lines=3)
            with gr.Row():
                max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
                top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Generated Text", lines=10)

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature, top_k],
        outputs=output,
    )
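    # The inputs list above is passed to generate_text positionally, so its
    # order must match the function signature:
    # (prompt, max_new_tokens, temperature, top_k).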

    gr.Markdown("### Note on inputs")
    gr.Markdown(
        "Because this model is small (135M parameters) and was trained on a "
        "specific dataset, it may not follow instructions the way ChatGPT "
        "does. It is best at completing text and stories."
    )

if __name__ == "__main__":
    demo.launch()