Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import re | |
| import time | |
class VibeThinkerModel:
    """Wrapper around the VibeThinker-1.5B causal LM and its tokenizer.

    Constructing an instance loads the tokenizer and model immediately via
    ``load_model``. ``generate_response`` is the public entry point; the
    ``_detect_loop`` / ``_truncate_loop`` helpers post-process the generated
    text when the model starts repeating itself.
    """

    # Phrase lengths (in whitespace-separated words) probed for repetition.
    _LOOP_PHRASE_LENGTHS = (10, 15, 20)

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load the VibeThinker tokenizer and fp16 model with transformers.

        Raises:
            Exception: any error raised while loading is printed and re-raised.
        """
        try:
            print("π Loading VibeThinker-1.5B with transformers...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                "WeiboAI/VibeThinker-1.5B",
                trust_remote_code=True
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                "WeiboAI/VibeThinker-1.5B",
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
            print(f"β Model loaded successfully on {self.device}")
            print(f"πΎ Model memory: ~{self.model.get_memory_footprint() / 1e9:.2f} GB")
        except Exception as e:
            print(f"β Error loading model: {e}")
            raise

    def generate_response(self, prompt, temperature=0.6, max_new_tokens=8192, max_thinking_tokens=4096):
        """Generate a completion for ``prompt``.

        Args:
            prompt: User problem statement inserted into the chat template.
            temperature: Sampling temperature. Values <= 0 switch to greedy
                decoding, because sampling requires a strictly positive value.
            max_new_tokens: Maximum number of new tokens to generate.
            max_thinking_tokens: Reasoning-length hint; it is only written
                into the system prompt, it is not enforced.

        Returns:
            Tuple ``(text, prompt_token_count, completion_token_count,
            seconds)``. On failure ``text`` carries the error message and the
            three numbers are 0.
        """
        if self.model is None or self.tokenizer is None:
            return "Model not loaded!", 0, 0, 0
        try:
            start_time = time.time()
            # Format prompt for competitive coding
            formatted_prompt = f"""<|im_start|>system
You are a competitive programming expert. Provide clear, concise solutions to coding problems.
Format your response as:
1. Brief analysis (2-3 sentences)
2. Solution approach
3. Implementation code
4. Test cases
Keep reasoning under {max_thinking_tokens} tokens. Be direct and avoid repetition.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
            prompt_length = inputs.input_ids.shape[1]

            generation_kwargs = {
                "max_new_tokens": int(max_new_tokens),
                "repetition_penalty": 1.1,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if temperature is not None and temperature > 0:
                generation_kwargs.update(
                    do_sample=True,
                    temperature=temperature,
                    top_p=0.95,
                    top_k=50,
                )
            else:
                # Temperature 0 is not valid for sampling; decode greedily.
                generation_kwargs["do_sample"] = False

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # Decode only the newly generated ids. Decoding the full sequence
            # with skip_special_tokens=True drops the chat markers, so neither
            # splitting on "<|im_start|>assistant" nor slicing by
            # len(formatted_prompt) lines up with the decoded text.
            new_token_ids = outputs[0][prompt_length:]
            generated_text = self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

            # Check for loops and truncate if needed
            if self._detect_loop(generated_text):
                generated_text = self._truncate_loop(generated_text)
                generated_text += "\n\nβ οΈ *[Repetitive content detected and truncated]*"

            generation_time = time.time() - start_time
            completion_length = outputs.shape[1] - prompt_length
            return generated_text, prompt_length, completion_length, generation_time
        except Exception as e:
            # UI boundary: report the failure as text instead of crashing the app.
            return f"Error during generation: {str(e)}", 0, 0, 0

    def _find_loop_end(self, text, min_repeats=3):
        """Locate the first phrase that repeats ``min_repeats`` or more times.

        Words are whitespace-separated tokens. For each probed phrase length,
        non-overlapping occurrences of every word n-gram are counted.

        Returns:
            Character offset in ``text`` just past the first copy of the
            earliest repeating phrase, or ``None`` when nothing repeats
            often enough.
        """
        tokens = [(m.group(), m.end()) for m in re.finditer(r"\S+", text)]
        words = [word for word, _ in tokens]
        for length in self._LOOP_PHRASE_LENGTHS:
            if len(words) < length * min_repeats:
                continue
            # phrase -> [first start index, next non-overlapping start, count]
            seen = {}
            # "+ 1" so that the final window of the text is examined too.
            for start in range(len(words) - length + 1):
                phrase = tuple(words[start:start + length])
                entry = seen.get(phrase)
                if entry is None:
                    seen[phrase] = [start, start + length, 1]
                elif start >= entry[1]:
                    entry[1] = start + length
                    entry[2] += 1
            loop_starts = [entry[0] for entry in seen.values() if entry[2] >= min_repeats]
            if loop_starts:
                last_word_of_first_copy = min(loop_starts) + length - 1
                return tokens[last_word_of_first_copy][1]
        return None

    def _detect_loop(self, text):
        """Return True when some 10/15/20-word phrase occurs three or more times."""
        return self._find_loop_end(text) is not None

    def _truncate_loop(self, text):
        """Cut ``text`` right after the first copy of the detected looping phrase.

        The cut is made on the original string, so newlines and indentation
        (e.g. inside generated code) before the loop are preserved. Text
        without a detectable loop is returned unchanged.
        """
        cut = self._find_loop_end(text)
        return text if cut is None else text[:cut]
def parse_model_output(text):
    """Split raw model output into reasoning, final answer and code blocks.

    Code blocks are extracted ONLY from the answer section, never from the
    reasoning section.

    Args:
        text: Generated text, possibly ending with the truncation marker
            ``[Repetitive content detected and truncated]``.

    Returns:
        Tuple ``(thinking_content, final_answer, code_blocks)`` where
        ``code_blocks`` is a list of ``(language, code)`` tuples; ``language``
        is ``''`` for fences without a language tag.
    """
    loop_warning = ""
    if "[Repetitive content detected and truncated]" in text:
        loop_warning = "\n\nβ οΈ **Note**: Repetitive content was detected and removed"
        # Drop the whole marker line, matching on its ASCII part so removal
        # does not depend on how the emoji prefix happens to be encoded.
        text = re.sub(r'[^\n]*\[Repetitive content detected and truncated\][^\n]*', '', text)

    # Try to find explicit thinking delimiters
    thinking_patterns = [
        r'<think>(.*?)</think>',
        r'<thinking>(.*?)</thinking>',
    ]
    thinking_content = ""
    answer_content = text
    for pattern in thinking_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            thinking_content = match.group(1).strip()
            answer_content = re.sub(pattern, '', text, flags=re.DOTALL | re.IGNORECASE).strip()
            break

    # If no explicit thinking tags, try to detect a leading reasoning section
    if not thinking_content:
        split_markers = [
            r'(.*?)(?=\n\n(?:Solution|Here\'s|Implementation|Code|Final).*?:)',
            r'(.*?)(?=\n\n```)',  # Before first code block
        ]
        reasoning_cues = ['step', 'approach', 'idea', 'first', 'we can', 'let\'s']
        for pattern in split_markers:
            match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
            if not match:
                continue
            potential_thinking = match.group(1).strip()
            if len(potential_thinking) <= 100:
                continue
            thinking_lower = potential_thinking.lower()
            if any(cue in thinking_lower for cue in reasoning_cues):
                thinking_content = potential_thinking
                # Slice at the end of the matched span. Using the length of
                # the stripped text would shift the cut whenever the output
                # starts with whitespace, leaking reasoning into the answer.
                answer_content = text[match.end(1):].strip()
                break

    # Extract code blocks ONLY from answer_content. The language tag may
    # contain '+', '#', '.' or '-' (e.g. "c++", "c#").
    code_pattern = r'```([\w+#.\-]*)\n(.*?)```'
    code_blocks = re.findall(code_pattern, answer_content, re.DOTALL)

    # Extract final answer: a \boxed{...} value wins over the full text
    answer_match = re.search(r'\\boxed\{([^}]+)\}', answer_content)
    if answer_match:
        final_answer = f"**Final Answer:** {answer_match.group(1)}"
    else:
        final_answer = answer_content
    final_answer += loop_warning
    return thinking_content, final_answer, code_blocks
def format_output_html(thinking, answer, code_blocks, prompt_tokens, completion_tokens, generation_time):
    """Render the parsed model output as a styled HTML fragment.

    All model-produced text (reasoning, answer, code) is HTML-escaped before
    interpolation so that characters such as ``<`` in generated code are shown
    literally instead of being parsed as markup.

    Args:
        thinking: Reasoning text; the collapsible section is omitted if empty.
        answer: Final answer text.
        code_blocks: Iterable of ``(language, code)`` tuples.
        prompt_tokens: Number of prompt tokens.
        completion_tokens: Number of generated tokens.
        generation_time: Generation wall time in seconds.

    Returns:
        The HTML fragment as a string.
    """
    # Local import: only ``escape`` is needed and the module name ``html``
    # is not otherwise in scope in this file.
    from html import escape

    total_tokens = prompt_tokens + completion_tokens
    # Rough estimate: ~1.3 tokens per whitespace-separated word.
    thinking_tokens_est = len(thinking.split()) * 1.3 if thinking else 0
    tokens_per_sec = completion_tokens / generation_time if generation_time > 0 else 0

    # Build thinking section HTML - plain text only
    thinking_html = ""
    if thinking:
        thinking_escaped = escape(thinking, quote=False)
        thinking_html = f"""
<details style="background: #f8f9fa; border: 2px solid #e9ecef; border-radius: 12px; padding: 20px; margin-bottom: 24px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<summary style="cursor: pointer; font-weight: 600; font-size: 16px; color: #495057; user-select: none; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 20px;">π§ </span>
<span>Reasoning Process (~{int(thinking_tokens_est):,} tokens)</span>
<span style="margin-left: auto; font-size: 12px; color: #6c757d;">Click to expand/collapse</span>
</summary>
<div style="margin-top: 16px; padding-top: 16px; border-top: 1px solid #dee2e6; color: #212529; line-height: 1.7; white-space: pre-wrap; font-size: 14px; font-family: 'SF Mono', Monaco, Consolas, monospace; background: #ffffff; padding: 16px; border-radius: 8px;">
{thinking_escaped}
</div>
</details>
"""

    # Build code blocks HTML
    code_html = ""
    if code_blocks:
        rendered_blocks = []
        for idx, (lang, code) in enumerate(code_blocks):
            # The tag is placed inside an inline JS string literal below, so
            # keep only characters that are inert in both HTML and JS.
            lang_display = re.sub(r"[^\w+#.\-]", "", lang or "") or "code"
            code_id = f"code_{idx}"
            code_clean = escape(code.strip(), quote=False)
            rendered_blocks.append(f"""
<div style="margin-bottom: 16px; background: #1e1e1e; border-radius: 12px; overflow: hidden; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<div style="background: #2d2d2d; padding: 12px 20px; color: #ffffff; font-weight: 600; font-size: 13px; display: flex; justify-content: space-between; align-items: center; border-bottom: 1px solid #3d3d3d;">
<span>{lang_display}</span>
<div style="display: flex; gap: 8px;">
<button onclick="navigator.clipboard.writeText(document.getElementById('{code_id}').textContent)"
style="background: #4CAF50; color: white; border: none; padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 500; transition: background 0.2s;"
onmouseover="this.style.background='#45a049'"
onmouseout="this.style.background='#4CAF50'">
π Copy
</button>
<button onclick="downloadCode(document.getElementById('{code_id}').textContent, '{lang_display}')"
style="background: #2196F3; color: white; border: none; padding: 6px 14px; border-radius: 6px; cursor: pointer; font-size: 12px; font-weight: 500; transition: background 0.2s;"
onmouseover="this.style.background='#0b7dda'"
onmouseout="this.style.background='#2196F3'">
πΎ Download
</button>
</div>
</div>
<pre style="margin: 0; padding: 20px; color: #d4d4d4; overflow-x: auto; font-family: 'SF Mono', Monaco, Consolas, monospace; font-size: 14px; line-height: 1.6;"><code id="{code_id}">{code_clean}</code></pre>
</div>
""")
        code_blocks_html = "".join(rendered_blocks)
        # NOTE(review): a <script> element delivered inside an HTML component
        # may not be executed by the browser - confirm that downloadCode is
        # actually reachable from the Download button.
        code_html = f"""
<div style="margin-top: 24px;">
<h3 style="color: #1a1a1a; font-size: 18px; font-weight: 600; margin-bottom: 16px; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 22px;">π»</span> Code
</h3>
{code_blocks_html}
</div>
<script>
function downloadCode(code, lang) {{
const extensions = {{
'python': 'py', 'javascript': 'js', 'java': 'java',
'cpp': 'cpp', 'c': 'c', 'html': 'html', 'css': 'css',
'typescript': 'ts', 'rust': 'rs', 'go': 'go',
}};
const ext = extensions[lang.toLowerCase()] || 'txt';
const filename = `solution.${{ext}}`;
const blob = new Blob([code], {{ type: 'text/plain' }});
const url = window.URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = filename;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
}}
</script>
"""

    # One card per statistic, in display order.
    stats = [
        ("Time", f"{generation_time:.1f}s"),
        ("Speed", f"{tokens_per_sec:.1f} t/s"),
        ("Prompt", f"{prompt_tokens:,}"),
        ("Output", f"{completion_tokens:,}"),
        ("Thinking", f"~{int(thinking_tokens_est):,}"),
        ("Total", f"{total_tokens:,}"),
    ]
    stats_html = "".join(
        f"""<div style="background: rgba(255,255,255,0.2); padding: 12px; border-radius: 8px;">
<div style="opacity: 0.9; font-size: 12px; margin-bottom: 4px;">{label}</div>
<div style="font-size: 20px; font-weight: bold;">{value}</div>
</div>
"""
        for label, value in stats
    )

    answer_escaped = escape(answer, quote=False)
    page = f"""
<div style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 100%; margin: 0 auto; background: #ffffff; color: #1a1a1a;">
<!-- Stats -->
<div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 12px; margin-bottom: 24px; color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
<h3 style="margin: 0 0 12px 0; font-size: 18px; font-weight: 600;">π Generation Stats</h3>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); gap: 12px; font-size: 14px;">
{stats_html}</div>
</div>
{thinking_html}
<!-- Answer -->
<div style="background: #ffffff; border: 2px solid #28a745; border-radius: 12px; padding: 24px; margin-bottom: 24px; box-shadow: 0 2px 4px rgba(40,167,69,0.1);">
<h3 style="margin: 0 0 16px 0; color: #28a745; font-size: 18px; font-weight: 600; display: flex; align-items: center; gap: 8px;">
<span style="font-size: 22px;">β </span> Final Solution
</h3>
<div style="color: #212529; line-height: 1.8; font-size: 15px; white-space: pre-wrap;">
{answer_escaped}
</div>
</div>
{code_html}
</div>
"""
    return page
# Build the single shared model instance at import time; the constructor
# loads the tokenizer and weights, so this line blocks until they are ready.
print("π Initializing VibeThinker-1.5B...")
vibe_model = VibeThinkerModel()
def generate_solution(prompt, temperature=0.6, max_tokens=8192, max_thinking_tokens=4096, progress=gr.Progress()):
    """Generate a solution for ``prompt`` and return it as formatted HTML.

    Args:
        prompt: Problem statement from the UI; ``None`` or blank input yields
            a short validation message instead of calling the model.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of new tokens; coerced to ``int`` because
            UI numeric widgets may deliver floats.
        max_thinking_tokens: Reasoning-length hint; coerced to ``int``.
        progress: Gradio progress tracker used to report the current stage.

    Returns:
        HTML string for the output component.
    """
    if not prompt or not prompt.strip():
        return "<p style='color: #dc3545; font-size: 16px; padding: 20px;'>β οΈ Please enter a problem to solve.</p>"

    progress(0, desc="π Initializing...")
    progress(0.2, desc="π§ Generating solution...")
    response, prompt_tokens, completion_tokens, gen_time = vibe_model.generate_response(
        prompt,
        temperature=temperature,
        max_new_tokens=int(max_tokens),
        max_thinking_tokens=int(max_thinking_tokens),
    )

    progress(0.8, desc="π Formatting output...")
    thinking, answer, code_blocks = parse_model_output(response)
    html_output = format_output_html(thinking, answer, code_blocks, prompt_tokens, completion_tokens, gen_time)

    progress(1.0, desc="β Complete!")
    return html_output
# Assemble the Gradio UI: inputs and settings on the left, rendered solution
# on the right, example prompts underneath.
# NOTE(review): the nesting below was reconstructed from a listing that had
# lost its indentation - confirm that the "Tips" text belongs inside the
# settings accordion and that the buttons sit in the left column.
ui_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="purple")

with gr.Blocks(
    theme=ui_theme,
    css=".gradio-container { max-width: 1400px !important; }"
) as demo:
    gr.Markdown("""
# π§ VibeThinker-1.5B Competitive Coding Assistant
**Optimized for**: Competitive programming (LeetCode, Codeforces, AtCoder) and algorithm challenges
π― **Best for**: Python algorithmic problems with clear input/output specifications
β οΈ **Note**: This model is specialized for competitive programming, not general software development
""")

    with gr.Row():
        # Left column: problem text, generation settings, action buttons.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="π Your Coding Problem",
                placeholder="Example: Write a Python function to find the longest palindromic substring in a given string. Include test cases.",
                lines=8
            )
            with gr.Accordion("βοΈ Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    label="π‘οΈ Temperature (0.6 recommended)",
                    minimum=0.0, maximum=1.0, value=0.6, step=0.1
                )
                max_tokens_slider = gr.Slider(
                    label="π Max New Tokens",
                    minimum=1024, maximum=16384, value=8192, step=1024
                )
                max_thinking_slider = gr.Slider(
                    label="π§ Max Thinking Tokens (hint for prompt)",
                    minimum=512, maximum=8192, value=3072, step=512
                )
                gr.Markdown("""
**Tips:**
- Lower thinking tokens (1024-2048) for faster, direct solutions
- Higher thinking tokens (4096-8192) for complex reasoning
- Temperature 0.6 balances creativity and accuracy
- Automatic loop detection and truncation
""")
            generate_btn = gr.Button("π Generate Solution", variant="primary", size="lg")
            clear_btn = gr.Button("ποΈ Clear", size="sm")

        # Right column: rendered HTML result.
        with gr.Column(scale=2):
            output_html = gr.HTML(label="Solution")

    # Event wiring: generate fills the HTML panel, clear empties both fields.
    generation_inputs = [prompt_input, temperature_slider, max_tokens_slider, max_thinking_slider]
    generate_btn.click(
        fn=generate_solution,
        inputs=generation_inputs,
        outputs=output_html
    )
    clear_btn.click(
        fn=lambda: ("", ""),
        outputs=[prompt_input, output_html]
    )

    # One single-element row per example, matching the single input component.
    example_prompts = [
        ["Write a Python function to find the maximum sum of a contiguous subarray (Kadane's Algorithm). Include edge cases and test with array [-2,1,-3,4,-1,2,1,-5,4]"],
        ["Implement a function to detect if a linked list has a cycle. Explain your approach and provide the solution."],
        ["Given an array of integers and a target sum, find two numbers that add up to the target. Optimize for time complexity."],
        ["Create a single page HTML application that lets the user choose a color and generates a matching color palette."],
    ]
    gr.Examples(
        examples=example_prompts,
        inputs=prompt_input
    )

if __name__ == "__main__":
    demo.launch()