#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import json
import os
import io
import re

from pydub import AudioSegment
import gradio as gr
import httpx

API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")


def escape_html(text):
    """Escape HTML special characters to prevent XSS"""
    if not isinstance(text, str):
        return text
    return (text
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#39;"))


def process_audio(audio_path):
    """
    Process audio: convert to wav, split if longer than 25 s.
    Returns a list of base64-encoded wav strings.
    """
    if not audio_path or not os.path.exists(audio_path):
        return []

    try:
        # Load audio (pydub handles mp3, wav, etc. automatically if ffmpeg is installed)
        audio = AudioSegment.from_file(audio_path)

        # Split into chunks of 25 seconds (25000 ms)
        chunk_length_ms = 25000
        chunks = []
        if len(audio) > chunk_length_ms:
            for i in range(0, len(audio), chunk_length_ms):
                chunks.append(audio[i:i + chunk_length_ms])
        else:
            chunks.append(audio)

        # Convert chunks to base64-encoded wav
        audio_data_list = []
        for chunk in chunks:
            buffer = io.BytesIO()
            chunk.export(buffer, format="wav")
            audio_data_list.append(base64.b64encode(buffer.getvalue()).decode())
        return audio_data_list
    except Exception as e:
        print(f"[DEBUG] Audio processing error: {e}")
        return []
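
# Example (hypothetical file): a 60 s recording splits at 25 s boundaries into
# three chunks (0-25 s, 25-50 s, 50-60 s), so process_audio("/tmp/sample.mp3")
# would return three base64 wav strings. Non-wav input requires ffmpeg.
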
def format_messages(system, history, user_text, audio_data_list=None):
    """Format the message list for the chat completions API"""
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    if not history:
        history = []

    # Process the conversation history
    for item in history:
        role = item.get("role") if isinstance(item, dict) else getattr(item, "role", None)
        content = item.get("content") if isinstance(item, dict) else getattr(item, "content", None)
        if not role or content is None:
            continue

        # If assistant content contains a rendered thinking block (thinking-block div),
        # extract only the response part
        if role == "assistant" and isinstance(content, str) and '<div class="thinking-block">' in content:
            # Find the end of the thinking block and extract what comes after
            # Match the entire thinking block
            pattern = r'<div class="thinking-block">.*?</div>\s*</div>\s*'
            remaining_content = re.sub(pattern, '', content, flags=re.DOTALL).strip()
            # If there's meaningful content after the thinking block, use it
            if remaining_content and not remaining_content.startswith('<'):
                content = remaining_content
            else:
                # Still in the thinking phase or no response yet, skip
                continue

        # Check for an audio component
        is_audio = isinstance(content, dict) and content.get("component") == "audio"
        if is_audio:
            audio_path = content["value"]["path"]
            if audio_path and os.path.exists(audio_path):
                try:
                    item_audio_data_list = process_audio(audio_path)
                    new_content = []
                    for audio_data in item_audio_data_list:
                        new_content.append({
                            "type": "input_audio",
                            "input_audio": {
                                "data": audio_data,
                                "format": "wav"
                            }
                        })
                    messages.append({"role": role, "content": new_content})
                except Exception as e:
                    print(f"[ERROR] Failed to process history audio: {e}")
        elif isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            # Process list items and ensure text comes before audio
            text_items = []
            audio_items = []
            other_items = []
            for c in content:
                # Check for an audio component inside the list
                is_c_audio = isinstance(c, dict) and c.get('component') == "audio"
                if is_c_audio:
                    audio_path = c["value"]["path"]
                    if audio_path and os.path.exists(audio_path):
                        try:
                            item_audio_data_list = process_audio(audio_path)
                            for audio_data in item_audio_data_list:
                                audio_items.append({
                                    "type": "input_audio",
                                    "input_audio": {
                                        "data": audio_data,
                                        "format": "wav"
                                    }
                                })
                        except Exception as e:
                            print(f"[ERROR] Failed to process history audio in list: {e}")
                elif isinstance(c, str):
                    text_items.append({"type": "text", "text": c})
                elif isinstance(c, dict):
                    # Distinguish between text and audio types
                    if c.get("type") == "text":
                        text_items.append(c)
                    elif c.get("type") == "input_audio":
                        audio_items.append(c)
                    else:
                        other_items.append(c)
            # Combine: text first, then audio, then others
            safe_content = text_items + audio_items + other_items
            if safe_content:
                messages.append({"role": role, "content": safe_content})

    # Append the current user message (text first, then audio)
    if user_text and audio_data_list:
        content = [{"type": "text", "text": user_text}]
        for audio_data in audio_data_list:
            content.append({
                "type": "input_audio",
                "input_audio": {"data": audio_data, "format": "wav"}
            })
        messages.append({"role": "user", "content": content})
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    elif audio_data_list:
        content = []
        for audio_data in audio_data_list:
            content.append({
                "type": "input_audio",
                "input_audio": {"data": audio_data, "format": "wav"}
            })
        messages.append({"role": "user", "content": content})

    return messages
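
# For reference, a combined text+audio user turn built by format_messages
# follows the OpenAI-style multimodal content layout (base64 abbreviated):
#   {"role": "user", "content": [
#       {"type": "text", "text": "What is said here?"},
#       {"type": "input_audio", "input_audio": {"data": "UklGR...", "format": "wav"}}
#   ]}
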
def chat(system_prompt, user_text, audio_file, history, max_tokens,
         temperature, top_p, show_thinking=True, model_name=None):
    """Chat function"""
    # If no model is specified, fall back to the global configuration
    if model_name is None:
        model_name = MODEL_NAME

    if not user_text and not audio_file:
        yield history or []
        return

    # Ensure history is a list of correctly formed entries
    history = history or []
    clean_history = []
    for item in history:
        if isinstance(item, dict) and 'role' in item and 'content' in item:
            clean_history.append(item)
        elif hasattr(item, "role") and hasattr(item, "content"):
            # Keep ChatMessage objects
            clean_history.append(item)
    history = clean_history

    # Process audio
    audio_data_list = []
    if audio_file:
        audio_data_list = process_audio(audio_file)

    messages = format_messages(system_prompt, history, user_text, audio_data_list)
    if not messages:
        yield history or []
        return

    # Debug: print the message payload with base64 audio data redacted
    debug_messages = []
    for msg in messages:
        if isinstance(msg, dict) and isinstance(msg.get("content"), list):
            new_content = []
            for item in msg["content"]:
                if isinstance(item, dict) and item.get("type") == "input_audio":
                    item_copy = item.copy()
                    if "input_audio" in item_copy:
                        audio_info = item_copy["input_audio"].copy()
                        if "data" in audio_info:
                            audio_info["data"] = f"[BASE64_AUDIO_DATA_LEN_{len(audio_info['data'])}]"
                        item_copy["input_audio"] = audio_info
                    new_content.append(item_copy)
                else:
                    new_content.append(item)
            msg_copy = msg.copy()
            msg_copy["content"] = new_content
            debug_messages.append(msg_copy)
        else:
            debug_messages.append(msg)
    print(f"[DEBUG] Messages to API: {json.dumps(debug_messages, ensure_ascii=False, indent=2)}")

    # Update history with the user message immediately (text first, then audio)
    if user_text and audio_file:
        # 1. Add the text message first
        history.append({"role": "user", "content": user_text})
        # 2. Add the audio message second
        history.append({"role": "user", "content": gr.Audio(audio_file)})
    elif user_text:
        # Text only
        history.append({"role": "user", "content": user_text})
    elif audio_file:
        # Audio only
        history.append({"role": "user", "content": gr.Audio(audio_file)})

    # Add a thinking placeholder
    if show_thinking:
        history.append({
            "role": "assistant",
            "content": (
                '<div class="thinking-block">\n'
                '<div class="thinking-header">💭 Thinking...</div>\n'
                '<div class="thinking-content">Processing your request...</div>\n'
                '</div>'
            )
        })
        yield history
    else:
        history.append({
            "role": "assistant",
            "content": "⏳ Generating response..."
        })
        yield history

    try:
        # Disable proxies so the internal API is reached directly
        with httpx.Client(base_url=API_BASE_URL, timeout=120, proxies={}) as client:
            response = client.post("/chat/completions", json={
                "model": model_name,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stream": True,
                "repetition_penalty": 1.0,
                "stop_token_ids": [151665]
            })

            if response.status_code != 200:
                error_msg = f"❌ API Error {response.status_code}"
                if response.status_code == 404:
                    error_msg += " - vLLM service not ready"
                elif response.status_code == 400:
                    error_msg += " - Bad request"
                elif response.status_code == 500:
                    error_msg += " - Model error"
                # Update the last message with the error
                history[-1]["content"] = error_msg
                yield history
                return
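            # Each line of the response body is expected to be an OpenAI-style
            # SSE event, e.g. (illustrative):
            #   data: {"choices": [{"delta": {"content": "Hel"}}]}
            # with a final "data: [DONE]" sentinel ending the stream.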
            # Process the streaming response
            buffer = ""
            is_thinking = True

            for line in response.iter_lines():
                if not line:
                    continue
                # Ensure the line is in string form
                if isinstance(line, bytes):
                    line = line.decode('utf-8')
                else:
                    line = str(line)

                if line.startswith('data: '):
                    data_str = line[6:]
                    if data_str.strip() == '[DONE]':
                        break
                    try:
                        data = json.loads(data_str)
                        if 'choices' in data and len(data['choices']) > 0:
                            delta = data['choices'][0].get('delta', {})
                            if 'content' in delta:
                                content = delta['content']
                                buffer += content

                                if is_thinking:
                                    if "</think>" in buffer:
                                        is_thinking = False
                                        parts = buffer.split("</think>", 1)
                                        think_content = parts[0]
                                        response_content = parts[1]
                                        if think_content.startswith("<think>"):
                                            think_content = think_content[len("<think>"):].strip()

                                        if show_thinking:
                                            # Format thinking with a custom styled block (escape HTML for safety)
                                            escaped_think = escape_html(think_content)
                                            formatted_content = (
                                                f'<div class="thinking-block">\n'
                                                f'<div class="thinking-header">💭 Thinking Process</div>\n'
                                                f'<div class="thinking-content">{escaped_think}</div>\n'
                                                f'</div>\n\n'
                                                f'{response_content}'
                                            )
                                            history[-1]["content"] = formatted_content
                                        else:
                                            # Don't show thinking, replace with the response directly
                                            history[-1]["content"] = response_content
                                    else:
                                        # Update the thinking message with the collapsible format (only if showing)
                                        if show_thinking:
                                            current_think = buffer
                                            if current_think.startswith("<think>"):
                                                current_think = current_think[len("<think>"):].strip()
                                            escaped_think = escape_html(current_think)
                                            formatted_content = (
                                                f'<div class="thinking-block">\n'
                                                f'<div class="thinking-header">💭 Thinking...</div>\n'
                                                f'<div class="thinking-content">{escaped_think}</div>\n'
                                                f'</div>'
                                            )
                                            history[-1]["content"] = formatted_content
                                else:
                                    # Already split, update the combined message
                                    parts = buffer.split("</think>", 1)
                                    think_content = parts[0]
                                    response_content = parts[1]
                                    if think_content.startswith("<think>"):
                                        think_content = think_content[len("<think>"):].strip()

                                    if show_thinking:
                                        # Update with formatted thinking + response
                                        escaped_think = escape_html(think_content)
                                        formatted_content = (
                                            f'<div class="thinking-block">\n'
                                            f'<div class="thinking-header">💭 Thinking Process</div>\n'
                                            f'<div class="thinking-content">{escaped_think}</div>\n'
                                            f'</div>\n\n'
                                            f'{response_content}'
                                        )
                                        history[-1]["content"] = formatted_content
                                    else:
                                        # Only show the response
                                        history[-1]["content"] = response_content

                                yield history
                    except json.JSONDecodeError:
                        continue

    except httpx.ConnectError:
        history[-1]["content"] = "❌ Cannot connect to vLLM API"
        yield history
    except Exception as e:
        history[-1]["content"] = f"❌ Error: {str(e)}"
        yield history
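
# Note: the parsing above assumes the model emits its reasoning wrapped in
# <think>...</think> before the final answer, e.g. the raw stream
# "<think>The speaker sounds calm...</think>The clip is a weather report."
# renders as a styled thinking block followed by the response text.
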
# Custom CSS for a better UI
custom_css = """
/* Global styles */
.gradio-container {
    max-width: 100% !important;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
}

/* Header styles */
.app-header {
    text-align: center;
    padding: 2.5rem 1.5rem;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    position: relative;
    overflow: hidden;
    border-radius: 16px;
    margin-bottom: 1.5rem;
    box-shadow: 0 8px 24px rgba(102, 126, 234, 0.35);
}

/* Header background decoration */
.app-header::before {
    content: '';
    position: absolute;
    top: -50%;
    right: -50%;
    width: 200%;
    height: 200%;
    background: radial-gradient(circle, rgba(255, 255, 255, 0.1) 0%, transparent 70%);
    animation: rotate 20s linear infinite;
}

@keyframes rotate {
    from { transform: rotate(0deg); }
    to { transform: rotate(360deg); }
}

.app-header h1 {
    margin: 0;
    font-size: 2.8rem;
    font-weight: 700;
    color: white !important;
    text-shadow: 0 3px 6px rgba(0, 0, 0, 0.25);
    letter-spacing: 1px;
    position: relative;
    z-index: 1;
}

.app-header p {
    color: rgba(255, 255, 255, 0.95) !important;
    text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
    position: relative;
    z-index: 1;
    line-height: 1.5;
}

/* Chatbot styles */
.chatbot-container {
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
    overflow: hidden;
}

/* Thinking-process styles, modeled on the Claude/ChatGPT look */
.thinking-block {
    background: linear-gradient(135deg, #f5f7fa 0%, #eef2f7 100%);
    border-left: 4px solid #667eea;
    padding: 16px 20px;
    margin: 12px 0;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}

.thinking-header {
    display: flex;
    align-items: center;
    font-weight: 600;
    color: #667eea;
    margin-bottom: 10px;
    font-size: 0.95rem;
}

.thinking-content {
    background: #ffffff;
    padding: 12px 16px;
    border-radius: 6px;
    font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, 'Courier New', monospace;
    font-size: 0.9rem;
    line-height: 1.6;
    color: #374151;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: 1px solid #e5e7eb;
}

/* Response divider */
.response-divider {
    border: none;
    height: 2px;
    background: linear-gradient(to right, transparent, #e5e7eb, transparent);
    margin: 20px 0;
}

/* Button styles */
.primary-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    transition: all 0.3s ease !important;
}

.primary-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}

/* Left panel styles */
.left-panel {
    background: #f9fafb;
    border-radius: 12px;
    padding: 1rem;
    height: 100%;
}

/* Input box styles */
.input-box textarea {
    border-radius: 8px !important;
    border: 2px solid #e5e7eb !important;
    transition: border-color 0.3s ease !important;
}

.input-box textarea:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

/* Input area headings */
h3 {
    color: #374151;
    font-size: 1.1rem;
    margin: 1rem 0 0.5rem 0;
}

/* Chat message style tweaks */
.message-wrap {
    padding: 1rem !important;
}

.message {
    padding: 1rem !important;
    border-radius: 12px !important;
    line-height: 1.6 !important;
}

/* User messages */
.message.user {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
}

/* Assistant messages */
.message.bot {
    background: #f9fafb !important;
    border: 1px solid #e5e7eb !important;
}

/* Overall left column styles */
.left-column {
    background: linear-gradient(to bottom, #ffffff 0%, #f9fafb 100%);
    border-radius: 12px;
    padding: 1rem;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}

/* Button row styles */
.button-row {
    margin-top: 1rem;
    gap: 0.5rem;
}

/* Scrollbar styling */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}

::-webkit-scrollbar-track {
    background: #f1f1f1;
    border-radius: 4px;
}

::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 4px;
}

::-webkit-scrollbar-thumb:hover {
    background: #555;
}
"""

# Gradio interface
with gr.Blocks(title="Step Audio R1", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div class="app-header">
        <h1>🔊 Step-Audio-R1</h1>
        <p>Advanced Audio-Language Model with Reasoning</p>
        <p>Comprehensive audio understanding: Speech, Sound, Music & Lyrics</p>
    </div>
    """)

    with gr.Row():
        # Left panel: input area
        with gr.Column(scale=1, min_width=350):
            # Configuration
            with gr.Accordion("⚙️ Configuration", open=False):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="You are a voice assistant with extensive experience in audio processing.",
                    placeholder="Enter system prompt...",
                    elem_classes=["input-box"]
                )
                max_tokens = gr.Slider(
                    1, 7192, value=6400,
                    label="Max Tokens",
                    info="Maximum tokens to generate"
                )
                temperature = gr.Slider(
                    0.0, 2.0, value=0.7,
                    label="Temperature",
                    info="Higher = more random"
                )
                top_p = gr.Slider(
                    0.0, 1.0, value=0.9,
                    label="Top P",
                    info="Nucleus sampling"
                )
                show_thinking = gr.Checkbox(
                    label="💭 Show Thinking Process",
                    value=True,
                    info="Display reasoning steps"
                )

            # Input area
            gr.Markdown("### 📝 Your Input")
            user_text = gr.Textbox(
                label="Text Message",
                lines=4,
                placeholder="Type your message here...",
                elem_classes=["input-box"],
                show_label=False
            )
            audio_file = gr.Audio(
                label="🎤 Audio Input",
                type="filepath",
                sources=["microphone", "upload"],
                show_label=True
            )

            # Buttons
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", scale=1, size="lg")
                submit_btn = gr.Button(
                    "🚀 Send",
                    variant="primary",
                    scale=2,
                    size="lg",
                    elem_classes=["primary-btn"]
                )

            # Usage guide at the bottom
            with gr.Accordion("📖 Quick Guide", open=False):
                gr.Markdown("""
                **Usage:**
                - Type text, upload audio, or both
                - Audio longer than 25 s is split automatically
                - Toggle the thinking-process display

                **Tips:**
                - Thinking is shown in a blue gradient block
                - History is cleaned automatically before being sent to the API
                - Adjust parameters under Configuration
                """)

        # Right panel: conversation area
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="💬 Conversation",
                height=700,
                type="messages",
                elem_classes=["chatbot-container"],
                show_label=True,
                avatar_images=(None, None),
                bubble_full_width=False
            )

    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot,
                max_tokens, temperature, top_p, show_thinking],
        outputs=[chatbot]
    )
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file]
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=6008)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()

    # Clear any proxy settings so local requests go straight to the API
    os.environ.update({
        'http_proxy': '', 'https_proxy': '', 'all_proxy': '',
        'HTTP_PROXY': '', 'HTTPS_PROXY': '', 'ALL_PROXY': ''
    })

    # Update the global model name
    if args.model:
        MODEL_NAME = args.model

    print(f"Starting Gradio at: http://{args.host}:{args.port}")
    print(f"API endpoint: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")

    demo.launch(server_name=args.host, server_port=args.port, share=False)