import os
import threading

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Config ---
MODEL_ID = os.getenv("MODEL_ID", "WeiboAI/VibeThinker-1.5B")
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a concise solver. Return a single short answer. Do not explain.",
)
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
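
# The defaults above can be overridden per deployment. On a Hugging Face Space this is
# done via the Space settings (Variables); the values below are illustrative only, they
# simply mirror the defaults:
#   MODEL_ID=WeiboAI/VibeThinker-1.5B
#   TEMPERATURE=0.2
#   TOP_P=0.9
#   MAX_NEW_TOKENS=256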

# --- Load ---
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Use CPU on ZeroGPU; float32 avoids CPU bf16 issues on some wheels
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
).to("cpu").eval()
print("Model loaded.")


def build_prompt(message, history):
    """Build a prompt from the chat history using the model's chat template if available."""
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages"-style history: {"role": ..., "content": ...}
            if turn.get("content"):
                messages.append({"role": turn["role"], "content": str(turn["content"])})
        else:
            # Gradio "tuples"-style history: (user_msg, assistant_msg)
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    messages.append({"role": "user", "content": str(message or "")})
    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback (shouldn't hit for Qwen-style models)
        prompt = f"[SYSTEM]\n{SYSTEM_PROMPT}\n[USER]\n{message}\n[ASSISTANT]\n"
    return prompt
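
# For reference, a ChatML-style chat template (used by Qwen-family models, which
# VibeThinker is assumed here to follow) typically renders the messages roughly as:
#   <|im_start|>system
#   You are a concise solver. ...<|im_end|>
#   <|im_start|>user
#   2 + 2 = ?<|im_end|>
#   <|im_start|>assistant
# The exact tokens come from the tokenizer's own chat template, not from this sketch.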


def chat_fn(message, history):
    """Streamed generation compatible with gr.ChatInterface (yields partial strings)."""
    prompt = build_prompt(message, history)
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_new_tokens=MAX_NEW_TOKENS,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Run generation in a background thread; the streamer yields decoded text as it arrives.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        # Optional hard stop: if you want one-line answers, cut at the first newline.
        # idx = partial.find("\n")
        # if idx != -1:
        #     yield partial[:idx].strip()
        #     return
        yield partial.strip()


demo = gr.ChatInterface(
    fn=chat_fn,
    title="VibeThinker-1.5B Chat (CPU)",
    description=(
        "WeiboAI/VibeThinker-1.5B • Simple streaming chat on CPU. "
        "Set MODEL_ID/TEMPERATURE/TOP_P/MAX_NEW_TOKENS in Space Variables."
    ),
)

if __name__ == "__main__":
    demo.queue().launch()
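
# Minimal client-side sketch (not part of the app): querying the running Space with
# gradio_client. "<user>/<space-name>" is a placeholder for your own Space ID, and
# "/chat" is the endpoint gr.ChatInterface usually exposes; check the Space's
# "Use via API" panel if it differs.
#
#   from gradio_client import Client
#
#   client = Client("<user>/<space-name>")
#   answer = client.predict("What is 17 * 24?", api_name="/chat")
#   print(answer)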