"""ZeroGPU Gradio demo: stable streaming chat with WeiboAI/VibeThinker-1.5B."""

import threading
import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "WeiboAI/VibeThinker-1.5B"
SYSTEM_PROMPT = "You are a concise solver. Give one clear final answer."
MAX_INPUT_TOKENS = 384
MAX_NEW_TOKENS = 128
TEMPERATURE = 0.4
TOP_P = 0.9
NO_TOKEN_TIMEOUT = 8  # seconds with no new token -> stop

print(f"⏳ Loading {MODEL_ID} …", flush=True)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    dtype=torch.bfloat16,  # <- use dtype (not torch_dtype)
    device_map="auto",
).eval()
print("✅ Model ready.", flush=True)


def _apply_template(messages):
    # Render the message list with the model's own chat template.
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def _clip_inputs(prompt_text, max_tokens):
    # Keep only the last `max_tokens` tokens of the prompt and move tensors to the model device.
    ids = tok([prompt_text], return_tensors="pt")
    if ids["input_ids"].shape[-1] > max_tokens:
        ids = {k: v[:, -max_tokens:] for k, v in ids.items()}
    return {k: v.to(model.device) for k, v in ids.items()}


@spaces.GPU(duration=90)
def respond(message, history):
    history = history or []
    msgs = [
        {"role": "system", "content": SYSTEM_PROMPT},
        *history,
        {"role": "user", "content": str(message)},
    ]
    prompt = _apply_template(msgs)
    inputs = _clip_inputs(prompt, MAX_INPUT_TOKENS)

    # Run generation in a background thread and stream decoded text out of it.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        repetition_penalty=1.18,
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tok.eos_token_id,
        use_cache=True,
    )
    th = threading.Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
    th.start()

    assistant = {"role": "assistant", "content": ""}
    out = list(history) + [assistant]
    last_token_time = time.time()
    last_yield = 0

    for chunk in streamer:
        assistant["content"] += chunk
        last_token_time = time.time()
        # Push an update at most every ~4 s: frequent enough to show progress,
        # sparse enough that the frontend never stalls under a flood of refreshes.
        now = time.time()
        if now - last_yield >= 4:
            yield out
            last_yield = now

    # Streamer is exhausted; give the generation thread up to NO_TOKEN_TIMEOUT s
    # after the last token to wrap up, then flag the reply if it is still running.
    while th.is_alive() and (time.time() - last_token_time) < NO_TOKEN_TIMEOUT:
        time.sleep(0.5)
    yield out
    if th.is_alive():
        assistant["content"] += f"\n\n(Stopped: no tokens for {NO_TOKEN_TIMEOUT}s)"
        yield out


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## 💡 VibeThinker-1.5B — ZeroGPU slice (stable streaming)")
    chat = gr.Chatbot(type="messages", height=520)
    box = gr.Textbox(placeholder="Ask a question…")
    send = gr.Button("Send", variant="primary")

    def pipeline(msg, hist):
        # Clear the textbox right away and stream the updated history into the chatbot.
        for updated in respond(msg, hist):
            yield "", updated

    box.submit(pipeline, [box, chat], [box, chat])
    send.click(pipeline, [box, chat], [box, chat])

if __name__ == "__main__":
    # Gradio 4.x: queue() has no concurrency_count; keep max_size if desired
    demo.queue(max_size=16).launch()