import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import time
class VibeThinkerModel:
def __init__(self):
self.model = None
self.tokenizer = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.load_model()
def load_model(self):
"""Load VibeThinker model with transformers"""
try:
print("πŸ”„ Loading VibeThinker-1.5B with transformers...")
self.tokenizer = AutoTokenizer.from_pretrained(
"WeiboAI/VibeThinker-1.5B",
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
"WeiboAI/VibeThinker-1.5B",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  # fp16 only when a GPU is present
device_map="auto",
trust_remote_code=True
)
print(f"βœ… Model loaded successfully on {self.device}")
print(f"πŸ’Ύ Model memory: ~{self.model.get_memory_footprint() / 1e9:.2f} GB")
except Exception as e:
print(f"❌ Error loading model: {e}")
raise
def generate_response(self, prompt, temperature=0.6, max_new_tokens=8192, max_thinking_tokens=4096):
"""
Generate response with thinking length control
Args:
prompt: Input prompt
temperature: Sampling temperature
max_new_tokens: Maximum new tokens to generate
max_thinking_tokens: Hint for reasoning depth (used in prompt)
"""
if not self.model or not self.tokenizer:
return "Model not loaded!", 0, 0, 0
try:
start_time = time.time()
# Format prompt for competitive coding
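            # ChatML-style chat markup (<|im_start|> ... <|im_end|>); the system
            # message nudges the model toward a bounded, structured answer.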
formatted_prompt = f"""<|im_start|>system
You are a competitive programming expert. Provide clear, concise solutions to coding problems.
Format your response as:
1. Brief analysis (2-3 sentences)
2. Solution approach
3. Implementation code
4. Test cases
Keep reasoning under {max_thinking_tokens} tokens. Be direct and avoid repetition.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
# Tokenize input
inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
prompt_length = inputs.input_ids.shape[1]
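            # prompt_length lets us report completion-only token counts below.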
# Generate with appropriate parameters
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=0.95,
top_k=50,
do_sample=True,
repetition_penalty=1.1,
pad_token_id=self.tokenizer.eos_token_id,
)
            # Decode only the newly generated tokens; skip_special_tokens=True
            # strips the <|im_start|>/<|im_end|> markers, so slicing by token
            # position is more reliable than string-matching on the chat template.
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
# Check for loops and truncate if needed
if self._detect_loop(generated_text):
generated_text = self._truncate_loop(generated_text)
generated_text += "\n\n⚠️ *[Repetitive content detected and truncated]*"
generation_time = time.time() - start_time
# Calculate token counts
completion_length = outputs.shape[1] - prompt_length
return generated_text, prompt_length, completion_length, generation_time
except Exception as e:
return f"Error during generation: {str(e)}", 0, 0, 0
def _detect_loop(self, text):
"""Detect if text contains repetitive loops"""
words = text.split()
if len(words) < 20:
return False
# Check if same phrase repeats 3+ times
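        # Slide a window of `length` words over the text; if that phrase occurs
        # at least twice more in the remaining text, treat it as a loop.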
for length in [10, 15, 20]:
if len(words) < length * 3:
continue
for i in range(len(words) - length * 3):
phrase = ' '.join(words[i:i+length])
rest = ' '.join(words[i+length:])
if rest.count(phrase) >= 2:
return True
return False
def _truncate_loop(self, text):
"""Truncate text at the start of detected loop"""
words = text.split()
for length in [10, 15, 20]:
if len(words) < length * 2:
continue
for i in range(len(words) - length * 2):
phrase = ' '.join(words[i:i+length])
rest_start = i + length
rest = ' '.join(words[rest_start:])
if phrase in rest:
return ' '.join(words[:rest_start])
return text
def parse_model_output(text):
"""
Parse model output to separate thinking and final answer
ONLY extract code from the final answer section, not from thinking
"""
loop_warning = ""
if "[Repetitive content detected and truncated]" in text:
loop_warning = "\n\n⚠️ **Note**: Repetitive content was detected and removed"
text = text.replace("⚠️ *[Repetitive content detected and truncated]*", "")
# Try to find explicit thinking delimiters
thinking_patterns = [
r'<think>(.*?)</think>',
r'<thinking>(.*?)</thinking>',
]
thinking_content = ""
answer_content = text
for pattern in thinking_patterns:
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
thinking_content = match.group(1).strip()
answer_content = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE).strip()
break
# If no explicit thinking tags, try to detect reasoning section
if not thinking_content:
split_markers = [
r'(.*?)(?=\n\n(?:Solution|Here\'s|Implementation|Code|Final).*?:)',
r'(.*?)(?=\n\n```)', # Before first code block
]
for pattern in split_markers:
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
potential_thinking = match.group(1).strip()
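                # Accept the split only if the candidate is reasonably long and
                # reads like reasoning rather than part of the final answer.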
if len(potential_thinking) > 100:
thinking_lower = potential_thinking.lower()
if any(word in thinking_lower for word in ['step', 'approach', 'idea', 'first', 'we can', 'let\'s']):
thinking_content = potential_thinking
answer_content = text[len(potential_thinking):].strip()
break
# Extract code blocks ONLY from answer_content
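    # Each match is a (language, body) tuple; the language tag may be empty.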
code_pattern = r'```(\w+)?\n(.*?)```'
code_blocks = re.findall(code_pattern, answer_content, re.DOTALL)
# Extract final answer
answer_match = re.search(r'\\boxed\{([^}]+)\}', answer_content)
if answer_match:
final_answer = f"**Final Answer:** {answer_match.group(1)}"
else:
final_answer = answer_content
final_answer += loop_warning
return thinking_content, final_answer, code_blocks
def format_output_html(thinking, answer, code_blocks, prompt_tokens, completion_tokens, generation_time):
"""Format output as styled HTML"""
total_tokens = prompt_tokens + completion_tokens
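    # Rough estimate of reasoning tokens: ~1.3 tokens per whitespace-separated word.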
thinking_tokens_est = len(thinking.split()) * 1.3 if thinking else 0
tokens_per_sec = completion_tokens / generation_time if generation_time > 0 else 0
# Build thinking section HTML - plain text only
thinking_html = ""
if thinking:
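        # Escape angle brackets so raw reasoning text cannot inject HTML.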
thinking_escaped = thinking.replace('<', '&lt;').replace('>', '&gt;')
thinking_html = f"""
<details style="background: #f8f9fa; border: 2px solid #e9ecef; border-radius: 12px; padding: 20px; margin-bottom: 24px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<summary style="cursor: pointer; font-weight: 600; font-size: 16px; color: #495057; user-select: none; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 20px;">🧠</span>
<span>Reasoning Process (~{int(thinking_tokens_est):,} tokens)</span>
<span style="margin-left: auto; font-size: 12px; color: #6c757d;">Click to expand/collapse</span>
</summary>
<div style="margin-top: 16px; padding-top: 16px; border-top: 1px solid #dee2e6; color: #212529; line-height: 1.7; white-space: pre-wrap; font-size: 14px; font-family: 'SF Mono', Monaco, Consolas, monospace; background: #ffffff; padding: 16px; border-radius: 8px;">
{thinking_escaped}
</div>
</details>
"""
# Build code blocks HTML
code_html = ""
if code_blocks:
code_blocks_html = ""
for idx, (lang, code) in enumerate(code_blocks):
lang_display = lang if lang else "code"
code_id = f"code_{idx}"
            # Escape HTML so code like "#include <vector>" renders verbatim
            # (the copy/download buttons read textContent, which un-escapes it).
            code_clean = code.strip().replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
code_blocks_html += f"""
<div style="margin-bottom: 16px; background: #1e1e1e; border-radius: 12px; overflow: hidden; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<div style="background: #2d2d2d; padding: 12px 20px; color: #ffffff; font-weight: 600; font-size: 13px; display: flex; justify-content: space-between; align-items: center; border-bottom: 1px solid #3d3d3d;">
<span>{lang_display}</span>
<div style="display: flex; gap: 8px;">
<button onclick="navigator.clipboard.writeText(document.getElementById('{code_id}').textContent)"
style="background: #4CAF50; color: white; border: none; padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 500; transition: background 0.2s;"
onmouseover="this.style.background='#45a049'"
onmouseout="this.style.background='#4CAF50'">
πŸ“‹ Copy
</button>
<button onclick="downloadCode(document.getElementById('{code_id}').textContent, '{lang_display}')"
style="background: #2196F3; color: white; border: none; padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 500; transition: background 0.2s;"
onmouseover="this.style.background='#0b7dda'"
onmouseout="this.style.background='#2196F3'">
πŸ’Ύ Download
</button>
</div>
</div>
<pre style="margin: 0; padding: 20px; color: #d4d4d4; overflow-x: auto; font-family: 'SF Mono', Monaco, Consolas, monospace; font-size: 14px; line-height: 1.6;"><code id="{code_id}">{code_clean}</code></pre>
</div>
"""
code_html = f"""
<div style="margin-top: 24px;">
<h3 style="color: #1a1a1a; font-size: 18px; font-weight: 600; margin-bottom: 16px; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 22px;">πŸ’»</span> Code
</h3>
{code_blocks_html}
</div>
<script>
function downloadCode(code, lang) {{
const extensions = {{
'python': 'py', 'javascript': 'js', 'java': 'java',
'cpp': 'cpp', 'c': 'c', 'html': 'html', 'css': 'css',
'typescript': 'ts', 'rust': 'rs', 'go': 'go',
}};
const ext = extensions[lang.toLowerCase()] || 'txt';
const filename = `solution.${{ext}}`;
const blob = new Blob([code], {{ type: 'text/plain' }});
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
}}
</script>
"""
html = f"""
<div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 100%; margin: 0 auto; background: #ffffff; color: #1a1a1a;">
<!-- Stats -->
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 12px; margin-bottom: 24px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<h3 style="margin: 0 0 12px 0; font-size: 18px; font-weight: 600;">πŸ“Š Generation Stats</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 12px; font-size: 14px;">
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Time</div>
<div style="font-size: 20px; font-weight: bold;">{generation_time:.1f}s</div>
</div>
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Speed</div>
<div style="font-size: 20px; font-weight: bold;">{tokens_per_sec:.1f} t/s</div>
</div>
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Prompt</div>
<div style="font-size: 20px; font-weight: bold;">{prompt_tokens:,}</div>
</div>
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Output</div>
<div style="font-size: 20px; font-weight: bold;">{completion_tokens:,}</div>
</div>
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Thinking</div>
<div style="font-size: 20px; font-weight: bold;">~{int(thinking_tokens_est):,}</div>
</div>
<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">Total</div>
<div style="font-size: 20px; font-weight: bold;">{total_tokens:,}</div>
</div>
</div>
</div>
{thinking_html}
<!-- Answer -->
<div style="background: #ffffff; border: 2px solid #28a745; border-radius: 12px; padding: 24px; margin-bottom: 24px; box-shadow: 0 2px 4px rgba(40,167,69,0.1);">
<h3 style="margin: 0 0 16px 0; color: #28a745; font-size: 18px; font-weight: 600; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 22px;">βœ…</span> Final Solution
</h3>
<div style="color: #212529; line-height: 1.8; font-size: 15px; white-space: pre-wrap;">
{answer}
</div>
</div>
{code_html}
</div>
"""
return html
# Initialize model
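# Loading happens once at startup so the first request does not pay the load cost.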
print("πŸ”„ Initializing VibeThinker-1.5B...")
vibe_model = VibeThinkerModel()
def generate_solution(prompt, temperature=0.6, max_tokens=8192, max_thinking_tokens=4096, progress=gr.Progress()):
"""Generate and format solution with progress tracking"""
if not prompt.strip():
return "<p style='color: #dc3545; font-size: 16px; padding: 20px;'>⚠️ Please enter a problem to solve.</p>"
progress(0, desc="πŸ”„ Initializing...")
progress(0.2, desc="🧠 Generating solution...")
response, prompt_tokens, completion_tokens, gen_time = vibe_model.generate_response(
prompt,
temperature=temperature,
max_new_tokens=max_tokens,
max_thinking_tokens=max_thinking_tokens
)
progress(0.8, desc="πŸ“ Formatting output...")
thinking, answer, code_blocks = parse_model_output(response)
html_output = format_output_html(thinking, answer, code_blocks, prompt_tokens, completion_tokens, gen_time)
progress(1.0, desc="βœ… Complete!")
return html_output
# Create Gradio interface
with gr.Blocks(
theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="purple"),
css=".gradio-container { max-width: 1400px !important; }"
) as demo:
gr.Markdown("""
# 🧠 VibeThinker-1.5B Competitive Coding Assistant
**Optimized for**: Competitive programming (LeetCode, Codeforces, AtCoder) and algorithm challenges
🎯 **Best for**: Python algorithmic problems with clear input/output specifications
⚠️ **Note**: This model is specialized for competitive programming, not general software development
""")
with gr.Row():
with gr.Column(scale=1):
prompt_input = gr.Textbox(
label="πŸ’­ Your Coding Problem",
placeholder="Example: Write a Python function to find the longest palindromic substring in a given string. Include test cases.",
lines=8
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
temperature_slider = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.6, step=0.1,  # do_sample requires temperature > 0
label="🌑️ Temperature (0.6 recommended)"
)
max_tokens_slider = gr.Slider(
minimum=1024, maximum=16384, value=8192, step=1024,
label="πŸ“ Max New Tokens"
)
max_thinking_slider = gr.Slider(
minimum=512, maximum=8192, value=3072, step=512,
label="🧠 Max Thinking Tokens (hint for prompt)"
)
gr.Markdown("""
**Tips:**
- Lower thinking tokens (1024-2048) for faster, direct solutions
- Higher thinking tokens (4096-8192) for complex reasoning
- Temperature 0.6 balances creativity and accuracy
- Automatic loop detection and truncation
""")
generate_btn = gr.Button("πŸš€ Generate Solution", variant="primary", size="lg")
clear_btn = gr.Button("πŸ—‘οΈ Clear", size="sm")
with gr.Column(scale=2):
output_html = gr.HTML(label="Solution")
generate_btn.click(
fn=generate_solution,
inputs=[prompt_input, temperature_slider, max_tokens_slider, max_thinking_slider],
outputs=output_html
)
clear_btn.click(
fn=lambda: ("", ""),
outputs=[prompt_input, output_html]
)
gr.Examples(
examples=[
["Write a Python function to find the maximum sum of a contiguous subarray (Kadane's Algorithm). Include edge cases and test with array [-2,1,-3,4,-1,2,1,-5,4]"],
["Implement a function to detect if a linked list has a cycle. Explain your approach and provide the solution."],
["Given an array of integers and a target sum, find two numbers that add up to the target. Optimize for time complexity."],
["Create a single page HTML application that lets the user choose a color and generates a matching color palette."],
],
inputs=prompt_input
)
if __name__ == "__main__":
demo.launch()