import os
import threading
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# --- Config ---
MODEL_ID = os.getenv("MODEL_ID", "WeiboAI/VibeThinker-1.5B")
SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a concise solver. Return a single short answer. Do not explain."
)
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
TOP_P = float(os.getenv("TOP_P", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "256"))
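# These defaults can be overridden without code changes; example (assumed)
# setup via the Space's "Variables" tab or a local shell:
#   export MODEL_ID=WeiboAI/VibeThinker-1.5B
#   export TEMPERATURE=0.2
#   export TOP_P=0.9
#   export MAX_NEW_TOKENS=256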

# --- Load ---
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Keep the model on CPU; float32 avoids bf16 issues on some CPU-only PyTorch wheels
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32
).to("cpu").eval()
print("Model loaded.")

def build_prompt(message, history):
    """Use the model's chat template if available."""
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # gr.ChatInterface may pass history as (user, assistant) pairs or, with
    # type="messages", as {"role": ..., "content": ...} dicts; accept both.
    for turn in history or []:
        if isinstance(turn, dict):
            if turn.get("content"):
                messages.append({"role": turn["role"], "content": str(turn["content"])})
        else:
            user_msg, assistant_msg = turn
            if user_msg:
                messages.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:
                messages.append({"role": "assistant", "content": str(assistant_msg)})
    messages.append({"role": "user", "content": str(message or '')})

    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Fallback (shouldn’t hit for Qwen-style models)
        prompt = f"[SYSTEM]\n{SYSTEM_PROMPT}\n[USER]\n{message}\n[ASSISTANT]\n"
    return prompt
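
# For reference (not verified against this exact checkpoint): Qwen-style chat
# templates typically render the messages list as ChatML, roughly:
#   <|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>
#   <|im_start|>user\n{message}<|im_end|>
#   <|im_start|>assistant\n
# add_generation_prompt=True appends that final assistant header so the model
# continues from there.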

def chat_fn(message, history):
    """Streamed generation compatible with gr.ChatInterface (yields partials)."""
    prompt = build_prompt(message, history)
    inputs = tokenizer([prompt], return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        **inputs,
        streamer=streamer,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        max_new_tokens=MAX_NEW_TOKENS,
        repetition_penalty=1.05,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        # Optional hard stop: if user wants one-liners, cut after first newline.
        # idx = partial.find("\n")
        # if idx != -1:
        #     yield partial[:idx].strip()
        #     return
        yield partial.strip()
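
def generate_once(message, history=None):
    """Non-streaming variant (illustrative sketch, not wired into the UI below):
    returns the full completion in one call instead of yielding partials.
    Handy for quick command-line or API-style checks."""
    prompt = build_prompt(message, history or [])
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            max_new_tokens=MAX_NEW_TOKENS,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()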

demo = gr.ChatInterface(
    fn=chat_fn,
    title="VibeThinker-1.5B Chat (CPU)",
    description="WeiboAI/VibeThinker-1.5B • Simple streaming chat on CPU. "
                "Set MODEL_ID/TEMPERATURE/TOP_P/MAX_NEW_TOKENS in Space Variables."
)

if __name__ == "__main__":
    demo.queue().launch()