import gradio as gr
import spaces  # HF Spaces ZeroGPU decorator - only available in HF Spaces environment
import torch
import torch.nn.functional as F
import os
import sys
from huggingface_hub import login

from config import ModelArgs, get_args
from model import DeepSeekV3, initialize_tokenizer
from tokenizer import Tokenizer
from inference import topk_sampling

# Global variables
tk = None
model = None
model_args = None

# Model paths - using the checkpoint in the HF Space
model_paths = {
    "Checkpoint 2000": "./checkpoint_2000.pt",
}


def initialize_app():
    """Initialize the app with tokenizer and model args"""
    global tk, model_args

    # Initialize model args
    model_args = ModelArgs()

    # Get HF token from environment variables (set in HF Spaces secrets)
    hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')

    if not hf_token:
        print("Warning: No HF_TOKEN found in environment variables.")
        print("Please set HF_TOKEN in your Hugging Face Spaces secrets.")
        print("Go to Settings -> Repository secrets -> New secret")
        print("Name: HF_TOKEN, Value: your_huggingface_token")
        # For now, we'll try to continue without authentication
        hf_token = None

    # Login to Hugging Face Hub for gated model access
    if hf_token:
        try:
            login(token=hf_token, add_to_git_credential=False)
            print("Successfully logged in to Hugging Face Hub")
        except Exception as e:
            print(f"Warning: Could not login to HF Hub: {e}")

    # Initialize tokenizer with HF token for gated model access
    if tk is None:
        try:
            tk = Tokenizer(hf_token=hf_token)
            tk = tk.ready_tokenizer()
            print("Tokenizer initialized successfully")
        except Exception as e:
            print(f"Error initializing tokenizer: {e}")
            print("This might be due to missing HF_TOKEN or lack of access to gated models.")
            print("The app will try to use a fallback tokenizer.")
            # Don't raise the error, let the tokenizer handle fallback
            try:
                tk = Tokenizer(hf_token=None)  # Force fallback
                tk = tk.ready_tokenizer()
                print("Fallback tokenizer initialized successfully")
            except Exception as fallback_error:
                print(f"Fallback tokenizer also failed: {fallback_error}")
                raise fallback_error

    # Initialize the global tokenizer in model.py
    initialize_tokenizer(hf_token=hf_token)


def load_model(model_path, device, model_args):
    """Load model from checkpoint"""
    model = DeepSeekV3(
        embeddings_dims=model_args.embeddings_dims,
        block_size=model_args.block_size,
        vocab_size=model_args.vocab_size,
        dropout=model_args.dropout,
        device=device
    )

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()
        print(f"Model loaded from {model_path}")
    else:
        print(f"Checkpoint {model_path} not found. Using randomly initialized model.")

    return model


@spaces.GPU(duration=120)
def generate_text(prompt, model_choice, max_length, temperature, top_k):
    """Generate text using the selected model and top-k sampling"""
    global tk, model_args

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the selected model
    model_path = model_paths.get(model_choice, "./checkpoint_2000.pt")
    model = load_model(model_path, device, model_args)
    model = model.to(device)

    try:
        generated_text = topk_sampling(
            model=model,
            prompt=prompt,
            device=device,
            max_length=max_length,
            top_k=top_k,
            temperature=temperature,
            tokenizer=tk
        )
        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"


def create_interface():
    """Create the Gradio interface"""
    global tk, model_args

    # Initialize the app
    initialize_app()

    with gr.Blocks(title="StoryKimi Text Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 StoryKimi Text Generator")
        gr.Markdown("Generate text using the Kimi K2 inspired StoryKimi model with ZeroGPU support.")
        gr.Markdown("⚡ **Powered by ZeroGPU** - Dynamic GPU allocation for efficient inference")

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Input Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time there lived a baby deer named Bambi."
                )

                with gr.Row():
                    model_dropdown = gr.Dropdown(
                        choices=list(model_paths.keys()),
                        label="Model Checkpoint",
                        value="Checkpoint 2000"
                    )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=128,
                        value=50,
                        step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.9,
                        step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-k"
                    )

                generate_btn = gr.Button("đŸŽ¯ Generate Text", variant="primary", size="lg")

            with gr.Column(scale=3):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=15,
                    interactive=False
                )

                with gr.Row():
                    clear_btn = gr.Button("đŸ—‘ī¸ Clear", variant="secondary")

        # Event handlers
        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                model_dropdown,
                max_length_slider,
                temperature_slider,
                top_k_slider
            ],
            outputs=output_text
        )

        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[prompt_input, output_text]
        )

        # Model information
        gr.Markdown("## â„šī¸ Model Information")
        gr.Markdown("""
        - **Model Architecture**: Kimi K2 inspired (StoryKimi)
        - **ZeroGPU**: Dynamic GPU allocation with H200 slice (70GB VRAM)
        - **GPU Duration**: 120 seconds maximum per generation
        - **Deployment**: Hugging Face Spaces with automatic scaling
        """)

        gr.Markdown("## 🚀 Features")
        gr.Markdown("""
        - **Top-k Sampling**: Control randomness with top-k token selection
        - **Temperature Control**: Adjust creativity vs coherence
        - **Variable Length**: Generate 10-128 tokens
        - **Real-time Generation**: Powered by ZeroGPU infrastructure
        """)

    return demo


if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()
    demo.launch()