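"""Gradio demo app for a from-scratch SmolLM-135M implementation.

Loads model weights from a local checkpoint (model.pt) and serves an
interactive text-completion interface with temperature and top-k sampling.
"""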
import os

import torch
import gradio as gr
from transformers import AutoTokenizer

from model import SmolLMForCausalLM, SmolLMConfig

# 1. Configuration constants
MODEL_CHECKPOINT = "model.pt"  # Expects the model weights to be in this file
TOKENIZER_ID = "HuggingFaceTB/SmolLM-135M"  # Standard SmolLM tokenizer
DEVICE = "cpu"  # The HF Spaces free tier is CPU-only; change to "cuda" if a GPU is available

# 2. Load Model and Tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)

print("Initializing model...")
config = SmolLMConfig()
model = SmolLMForCausalLM(config)
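# Note: this assumes SmolLMConfig's defaults match the configuration used for
# the 135M training run, so the checkpoint's tensor shapes line up with the model.
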
# 3. Load Weights
if os.path.exists(MODEL_CHECKPOINT):
    print(f"Loading weights from {MODEL_CHECKPOINT}...")
    try:
        # Map to CPU to be safe, regardless of where the checkpoint was saved
        checkpoint = torch.load(MODEL_CHECKPOINT, map_location=torch.device('cpu'))
        # The file may be a full training checkpoint (a dict with 'model_state_dict')
        # or a bare state dict
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        # Strip the '_orig_mod.' prefix that torch.compile adds to parameter names
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("_orig_mod."):
                new_state_dict[k[len("_orig_mod."):]] = v
            else:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        print("Weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
        print("Running with randomly initialized weights for demonstration.")
else:
    print(f"Warning: {MODEL_CHECKPOINT} not found! Running with random weights.")

model.to(DEVICE)
model.eval()
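# Optional sanity check (an addition, not from the original script): report the
# parameter count to confirm the instantiated architecture is roughly 135M.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model size: {n_params / 1e6:.1f}M parameters")
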
# 4. Generation Function
def generate_text(prompt, max_new_tokens, temperature, top_k):
    if not prompt:
        return "Please enter a prompt."
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Text generation loop: a simple autoregressive loop similar to the
    # training script's generate function, with temperature and top-k
    # sampling added for more variety in the demo.
    curr_input_ids = input_ids
    with torch.no_grad():
        for _ in range(int(max_new_tokens)):
            # Get logits for the last position
            logits = model(curr_input_ids)
            next_token_logits = logits[:, -1, :]
            # Apply temperature scaling
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
            else:
                # Greedy decoding when temperature is 0: take the argmax
                # directly and skip the sampling step below.
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
                curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
                continue
            # Apply top-k filtering: mask every logit below the k-th largest
            # to -inf so only the top k tokens can be sampled
            if top_k > 0:
                v, _ = torch.topk(next_token_logits, min(int(top_k), next_token_logits.size(-1)))
                next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
            probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            # Sample the next token from the filtered distribution
            next_token_id = torch.multinomial(probs, num_samples=1)
            curr_input_ids = torch.cat([curr_input_ids, next_token_id], dim=1)
            # Optional: stop early on EOS (if one was defined and used in training)
            # if next_token_id.item() == tokenizer.eos_token_id:
            #     break
    output_text = tokenizer.decode(curr_input_ids[0].tolist(), skip_special_tokens=True)
    return output_text
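# Quick local smoke test (hypothetical values, kept commented out so the Space
# does not run generation at startup):
# print(generate_text("Once upon a time", max_new_tokens=20, temperature=0.8, top_k=40))
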
# 5. Build Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# SmolLM-135M Implementation Demo")
    gr.Markdown("This is a demo of the 135M parameter transformer model trained from scratch.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Once upon a time...", lines=3)
            with gr.Row():
                max_tokens = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
                top_k = gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-K")
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Generated Text", lines=10)
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens, temperature, top_k],
        outputs=output,
    )
    gr.Markdown("### Note on inputs")
    gr.Markdown("Because this model is small (135M) and trained on a specific dataset, it may not follow instructions like ChatGPT. It is best at completing text/stories.")

if __name__ == "__main__":
    demo.launch()
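# To run locally: `python app.py`, then open the URL Gradio prints
# (by default http://127.0.0.1:7860).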