import gradio as gr
import spaces  # HF Spaces ZeroGPU decorator - only available in HF Spaces environment
import torch
import torch.nn.functional as F
import os
import sys
from huggingface_hub import login

from config import ModelArgs, get_args
from model import DeepSeekV3, initialize_tokenizer
from tokenizer import Tokenizer
from inference import topk_sampling
# Global variables
tk = None
model = None
model_args = None

# Model paths - using the checkpoint in the HF Space
model_paths = {
    "Checkpoint 2000": "./checkpoint_2000.pt",
}
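# Additional checkpoints can be exposed in the dropdown by extending this mapping.
# The entry below is hypothetical and only illustrates the expected
# "label -> local .pt path" form:
# model_paths["Checkpoint 4000"] = "./checkpoint_4000.pt"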
def initialize_app():
    """Initialize the app with tokenizer and model args"""
    global tk, model_args

    # Initialize model args
    model_args = ModelArgs()

    # Get HF token from environment variables (set in HF Spaces secrets)
    hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')

    if not hf_token:
        print("Warning: No HF_TOKEN found in environment variables.")
        print("Please set HF_TOKEN in your Hugging Face Spaces secrets.")
        print("Go to Settings -> Repository secrets -> New secret")
        print("Name: HF_TOKEN, Value: your_huggingface_token")
        # For now, we'll try to continue without authentication
        hf_token = None

    # Login to Hugging Face Hub for gated model access
    if hf_token:
        try:
            login(token=hf_token, add_to_git_credential=False)
            print("Successfully logged in to Hugging Face Hub")
        except Exception as e:
            print(f"Warning: Could not login to HF Hub: {e}")

    # Initialize tokenizer with HF token for gated model access
    if tk is None:
        try:
            tk = Tokenizer(hf_token=hf_token)
            tk = tk.ready_tokenizer()
            print("Tokenizer initialized successfully")
        except Exception as e:
            print(f"Error initializing tokenizer: {e}")
            print("This might be due to missing HF_TOKEN or lack of access to gated models.")
            print("The app will try to use a fallback tokenizer.")
            # Don't raise the error, let the tokenizer handle fallback
            try:
                tk = Tokenizer(hf_token=None)  # Force fallback
                tk = tk.ready_tokenizer()
                print("Fallback tokenizer initialized successfully")
            except Exception as fallback_error:
                print(f"Fallback tokenizer also failed: {fallback_error}")
                raise fallback_error

    # Initialize the global tokenizer in model.py
    initialize_tokenizer(hf_token=hf_token)
def load_model(model_path, device, model_args):
    """Load model from checkpoint"""
    model = DeepSeekV3(
        embeddings_dims=model_args.embeddings_dims,
        block_size=model_args.block_size,
        vocab_size=model_args.vocab_size,
        dropout=model_args.dropout,
        device=device
    )

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()
        print(f"Model loaded from {model_path}")
    else:
        print(f"Checkpoint {model_path} not found. Using randomly initialized model.")

    return model
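# The `spaces` import above is otherwise unused, so ZeroGPU would never allocate a GPU
# for generation. A minimal sketch of the missing piece, assuming the standard HF Spaces
# ZeroGPU decorator and the 120-second budget quoted in the UI text below:
@spaces.GPU(duration=120)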
def generate_text(prompt, model_choice, max_length, temperature, top_k):
    """Generate text using the selected model and top-k sampling"""
    global tk, model_args

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the selected model
    model_path = model_paths.get(model_choice, "./checkpoint_2000.pt")
    model = load_model(model_path, device, model_args)
    model = model.to(device)

    try:
        generated_text = topk_sampling(
            model=model,
            prompt=prompt,
            device=device,
            max_length=max_length,
            top_k=top_k,
            temperature=temperature,
            tokenizer=tk
        )
        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"
def create_interface():
    """Create the Gradio interface"""
    global tk, model_args

    # Initialize the app
    initialize_app()

    with gr.Blocks(title="StoryKimi Text Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📖 StoryKimi Text Generator")
        gr.Markdown("Generate text using the Kimi K2 inspired StoryKimi model with ZeroGPU support.")
        gr.Markdown("⚡ **Powered by ZeroGPU** - Dynamic GPU allocation for efficient inference")

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Input Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time there lived a baby deer named Bambi."
                )

                with gr.Row():
                    model_dropdown = gr.Dropdown(
                        choices=list(model_paths.keys()),
                        label="Model Checkpoint",
                        value="Checkpoint 2000"
                    )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=128,
                        value=50,
                        step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.9,
                        step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-k"
                    )

                generate_btn = gr.Button("🎯 Generate Text", variant="primary", size="lg")

            with gr.Column(scale=3):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=15,
                    interactive=False
                )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Event handlers
        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                model_dropdown,
                max_length_slider,
                temperature_slider,
                top_k_slider
            ],
            outputs=output_text
        )

        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[prompt_input, output_text]
        )

        # Model information
        gr.Markdown("## ℹ️ Model Information")
        gr.Markdown("""
        - **Model Architecture**: Kimi K2 inspired (StoryKimi)
        - **ZeroGPU**: Dynamic GPU allocation with H200 slice (70GB VRAM)
        - **GPU Duration**: 120 seconds maximum per generation
        - **Deployment**: Hugging Face Spaces with automatic scaling
        """)

        gr.Markdown("## 🚀 Features")
        gr.Markdown("""
        - **Top-k Sampling**: Control randomness with top-k token selection
        - **Temperature Control**: Adjust creativity vs coherence
        - **Variable Length**: Generate 10-128 tokens
        - **Real-time Generation**: Powered by ZeroGPU infrastructure
        """)

    return demo
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()
    demo.launch()
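# Local usage sketch (assumes this script is saved as app.py and that checkpoint_2000.pt
# plus the config/model/tokenizer/inference modules sit alongside it):
#   HF_TOKEN=<your_huggingface_token> python app.py
# On Hugging Face Spaces, set HF_TOKEN under Settings -> Repository secrets instead,
# as the warning in initialize_app() describes.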