# StoryKimi-Zero / app.py
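"""Gradio demo for StoryKimi, a Kimi K2 inspired story-generation model, served on
Hugging Face Spaces with ZeroGPU-backed inference."""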
import gradio as gr
import spaces # HF Spaces ZeroGPU decorator - only available in HF Spaces environment
import torch
import torch.nn.functional as F
import os
import sys
from huggingface_hub import login
from config import ModelArgs, get_args
from model import DeepSeekV3, initialize_tokenizer
from tokenizer import Tokenizer
from inference import topk_sampling
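
# config, model, tokenizer and inference are project-local modules that ship with this Space.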

# Global variables
tk = None
model = None
model_args = None

# Model paths - using the checkpoint in the HF Space
model_paths = {
    "Checkpoint 2000": "./checkpoint_2000.pt",
}
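# Additional checkpoints can be listed here; the "Model Checkpoint" dropdown is built from these keys.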


def initialize_app():
    """Initialize the app with tokenizer and model args"""
    global tk, model_args

    # Initialize model args
    model_args = ModelArgs()

    # Get HF token from environment variables (set in HF Spaces secrets)
    hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGINGFACE_HUB_TOKEN')
    if not hf_token:
        print("Warning: No HF_TOKEN found in environment variables.")
        print("Please set HF_TOKEN in your Hugging Face Spaces secrets.")
        print("Go to Settings -> Repository secrets -> New secret")
        print("Name: HF_TOKEN, Value: your_huggingface_token")
        # For now, we'll try to continue without authentication
        hf_token = None

    # Login to Hugging Face Hub for gated model access
    if hf_token:
        try:
            login(token=hf_token, add_to_git_credential=False)
            print("Successfully logged in to Hugging Face Hub")
        except Exception as e:
            print(f"Warning: Could not login to HF Hub: {e}")

    # Initialize tokenizer with HF token for gated model access
    if tk is None:
        try:
            tk = Tokenizer(hf_token=hf_token)
            tk = tk.ready_tokenizer()
            print("Tokenizer initialized successfully")
        except Exception as e:
            print(f"Error initializing tokenizer: {e}")
            print("This might be due to missing HF_TOKEN or lack of access to gated models.")
            print("The app will try to use a fallback tokenizer.")
            # Don't raise the error, let the tokenizer handle fallback
            try:
                tk = Tokenizer(hf_token=None)  # Force fallback
                tk = tk.ready_tokenizer()
                print("Fallback tokenizer initialized successfully")
            except Exception as fallback_error:
                print(f"Fallback tokenizer also failed: {fallback_error}")
                raise fallback_error

    # Initialize the global tokenizer in model.py
    initialize_tokenizer(hf_token=hf_token)
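

# NOTE: load_model assumes the checkpoint file stores a plain state_dict (it is passed
# directly to load_state_dict); a checkpoint that wraps the weights in a dict such as
# {"model_state_dict": ...} would need to be unpacked first.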
def load_model(model_path, device, model_args):
    """Load model from checkpoint"""
    model = DeepSeekV3(
        embeddings_dims=model_args.embeddings_dims,
        block_size=model_args.block_size,
        vocab_size=model_args.vocab_size,
        dropout=model_args.dropout,
        device=device
    )

    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()
        print(f"Model loaded from {model_path}")
    else:
        print(f"Checkpoint {model_path} not found. Using randomly initialized model.")

    return model
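

# ZeroGPU: @spaces.GPU requests a GPU slice for each call, capped at 120 seconds
# (duration=120). The checkpoint is reloaded on every request, which keeps the handler
# stateless at the cost of some extra per-call latency.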
@spaces.GPU(duration=120)
def generate_text(prompt, model_choice, max_length, temperature, top_k):
    """Generate text using the selected model and top-k sampling"""
    global tk, model_args

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load the selected model
    model_path = model_paths.get(model_choice, "./checkpoint_2000.pt")
    model = load_model(model_path, device, model_args)
    model = model.to(device)

    try:
        generated_text = topk_sampling(
            model=model,
            prompt=prompt,
            device=device,
            max_length=max_length,
            top_k=top_k,
            temperature=temperature,
            tokenizer=tk
        )
        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"


def create_interface():
    """Create the Gradio interface"""
    global tk, model_args

    # Initialize the app
    initialize_app()

    with gr.Blocks(title="StoryKimi Text Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 StoryKimi Text Generator")
        gr.Markdown("Generate text using the Kimi K2 inspired StoryKimi model with ZeroGPU support.")
        gr.Markdown("⚡ **Powered by ZeroGPU** - Dynamic GPU allocation for efficient inference")

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Input Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="Once upon a time there lived a baby deer named Bambi."
                )

                with gr.Row():
                    model_dropdown = gr.Dropdown(
                        choices=list(model_paths.keys()),
                        label="Model Checkpoint",
                        value="Checkpoint 2000"
                    )

                with gr.Row():
                    max_length_slider = gr.Slider(
                        minimum=10,
                        maximum=128,
                        value=50,
                        step=10,
                        label="Max Length"
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.9,
                        step=0.1,
                        label="Temperature"
                    )

                with gr.Row():
                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-k"
                    )

                generate_btn = gr.Button("🎯 Generate Text", variant="primary", size="lg")

            with gr.Column(scale=3):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=15,
                    interactive=False
                )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary")

        # Event handlers
        generate_btn.click(
            fn=generate_text,
            inputs=[
                prompt_input,
                model_dropdown,
                max_length_slider,
                temperature_slider,
                top_k_slider
            ],
            outputs=output_text
        )

        clear_btn.click(
            fn=lambda: ("", ""),
            outputs=[prompt_input, output_text]
        )

        # Model information
        gr.Markdown("## ℹ️ Model Information")
        gr.Markdown("""
- **Model Architecture**: Kimi K2 inspired (StoryKimi)
- **ZeroGPU**: Dynamic GPU allocation with H200 slice (70GB VRAM)
- **GPU Duration**: 120 seconds maximum per generation
- **Deployment**: Hugging Face Spaces with automatic scaling
""")

        gr.Markdown("## 🚀 Features")
        gr.Markdown("""
- **Top-k Sampling**: Control randomness with top-k token selection
- **Temperature Control**: Adjust creativity vs coherence
- **Variable Length**: Generate 10-128 tokens
- **Real-time Generation**: Powered by ZeroGPU infrastructure
""")

    return demo
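

# Entry point: build the Blocks UI and start the Gradio server. demo.queue() could be
# enabled before launch() if many concurrent generations are expected.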
if __name__ == "__main__":
    # Create and launch the interface
    demo = create_interface()
    demo.launch()
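

# Running locally with `python app.py` should also work: outside HF Spaces the `spaces`
# GPU decorator is expected to be a no-op, and the torch.cuda.is_available() check in
# generate_text falls back to CPU.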