import os
import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Config ---
MODEL_ID = os.getenv("MODEL_ID", "WeiboAI/VibeThinker-1.5B")
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a concise solver. Return a single short answer. Do not explain.",
)
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))

# --- Load ---
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Run on CPU (even on ZeroGPU hardware); float32 avoids bf16 issues on some CPU wheels.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).to("cpu").eval()
print("Model loaded.")


def build_prompt(message, history):
    """Build a prompt using the model's chat template if available."""
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # gr.ChatInterface passes history as (user, assistant) tuples or as
    # {"role": ..., "content": ...} dicts depending on the Gradio version.
    for item in history or []:
        if isinstance(item, dict):
            if item.get("content"):
                messages.append({"role": item["role"], "content": str(item["content"])})
        else:
            user_msg, assistant_msg = item
            if user_msg:
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    messages.append({"role": "user", "content": str(message or "")})
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback (shouldn't hit for Qwen-style models).
        prompt = f"[SYSTEM]\n{SYSTEM_PROMPT}\n[USER]\n{message}\n[ASSISTANT]\n"
    return prompt


def chat_fn(message, history):
    """Streamed generation compatible with gr.ChatInterface (yields partials)."""
    prompt = build_prompt(message, history)
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_new_tokens=MAX_NEW_TOKENS,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generate() in a background thread so tokens can be yielded as they arrive.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        # Optional hard stop: if user wants one-liners, cut after first newline.
        # idx = partial.find("\n")
        # if idx != -1:
        #     yield partial[:idx].strip()
        #     return
        yield partial.strip()
    thread.join()


demo = gr.ChatInterface(
    fn=chat_fn,
    title="VibeThinker-1.5B Chat (CPU)",
    description="WeiboAI/VibeThinker-1.5B • Simple streaming chat on CPU. "
                "Set MODEL_ID/TEMPERATURE/TOP_P/MAX_NEW_TOKENS in Space Variables.",
)

if __name__ == "__main__":
    demo.queue().launch()
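
# A minimal local-run sketch (assumes this script is saved as app.py; the
# values shown are illustrative, not requirements):
#
#   MODEL_ID=WeiboAI/VibeThinker-1.5B TEMPERATURE=0.2 TOP_P=0.9 \
#   MAX_NEW_TOKENS=256 python app.py
#
# On a Hugging Face Space, the same variables can be set under
# Settings -> Variables and secrets instead of on the command line.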