#!/usr/bin/env python3
"""
Step Audio R1 vLLM Gradio Interface
"""
import base64
import io
import json
import os
import re
import time

import gradio as gr
import httpx
from pydub import AudioSegment

API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:9999/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Step-Audio-R1")
SECRET = os.getenv("API_SECRET", "")

# Audio size limit (10 MB)
MAX_AUDIO_SIZE_MB = 10
MAX_AUDIO_SIZE_BYTES = MAX_AUDIO_SIZE_MB * 1024 * 1024

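# Uploaded audio is re-encoded to wav before being base64-embedded in the request,
# so the quota above is measured against the converted wav size rather than the size
# of the original upload.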
def get_wav_size(audio_path):
    """Calculate the size of the audio after converting to wav (in bytes)."""
    if not audio_path or not os.path.exists(audio_path):
        return 0
    try:
        audio = AudioSegment.from_file(audio_path)
        buffer = io.BytesIO()
        audio.export(buffer, format="wav")
        return len(buffer.getvalue())
    except Exception as e:
        print(f"[ERROR] Failed to calculate wav size: {e}")
        return 0


def get_audio_size_info(used_size_bytes, current_audio_path=None):
    """Build the audio size usage info message."""
    current_size = 0
    if current_audio_path and os.path.exists(current_audio_path):
        current_size = get_wav_size(current_audio_path)
    remaining = MAX_AUDIO_SIZE_BYTES - used_size_bytes
    used_mb = used_size_bytes / (1024 * 1024)
    remaining_mb = remaining / (1024 * 1024)
    current_mb = current_size / (1024 * 1024)
    if used_size_bytes == 0 and current_size == 0:
        return f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available"
    elif current_size > 0:
        new_remaining = remaining - current_size
        new_remaining_mb = new_remaining / (1024 * 1024)
        if new_remaining < 0:
            return f"📊 ⚠️ Current audio ({current_mb:.2f}MB) exceeds remaining limit ({remaining_mb:.2f}MB)"
        return f"📊 Audio: {used_mb:.2f}MB used + {current_mb:.2f}MB pending = {new_remaining_mb:.2f}MB remaining"
    else:
        return f"📊 Audio limit: {used_mb:.2f}MB used, {remaining_mb:.2f}MB remaining (max {MAX_AUDIO_SIZE_MB}MB)"


def escape_html(text):
    """Escape HTML special characters to prevent XSS."""
    if not isinstance(text, str):
        return text
    return (text
            .replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&#39;"))


def process_audio(audio_path):
    """
    Process audio: convert to wav, split if longer than 25 s.
    Returns a list of base64-encoded wav strings.
    """
    if not audio_path or not os.path.exists(audio_path):
        return []
    try:
        # Load audio (pydub handles mp3, wav, etc. automatically if ffmpeg is installed)
        audio = AudioSegment.from_file(audio_path)
        # Split into chunks of 25 seconds (25000 ms)
        chunk_length_ms = 25000
        chunks = []
        if len(audio) > chunk_length_ms:
            for i in range(0, len(audio), chunk_length_ms):
                chunk = audio[i:i + chunk_length_ms]
                chunks.append(chunk)
        else:
            chunks.append(audio)
        # Convert chunks to base64-encoded wav
        audio_data_list = []
        for chunk in chunks:
            buffer = io.BytesIO()
            chunk.export(buffer, format="wav")
            encoded = base64.b64encode(buffer.getvalue()).decode()
            audio_data_list.append(encoded)
        return audio_data_list
    except Exception as e:
        print(f"[DEBUG] Audio processing error: {e}")
        return []


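# Messages are assembled in the OpenAI chat-completions style: a plain string for
# text-only turns, or a list of {"type": "text"} / {"type": "input_audio"} parts when
# a turn carries audio. Audio is sent as a base64 wav data URI, and text parts are
# always placed before audio parts within a turn.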
def format_messages(system, history, user_text, audio_data_list=None):
    """Build the message list for the chat completions request."""
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    if not history:
        history = []
    # Process conversation history
    for item in history:
        role = item.get("role") if isinstance(item, dict) else getattr(item, "role", None)
        content = item.get("content") if isinstance(item, dict) else getattr(item, "content", None)
        if not role or content is None:
            continue
        # If the content contains a thinking process (thinking-block div), keep only the response part
        if role == "assistant" and isinstance(content, str) and '<div class="thinking-block">' in content:
            # Match the entire thinking block and extract what comes after it
            pattern = r'<div class="thinking-block">.*?</div>\s*</div>\s*'
            remaining_content = re.sub(pattern, '', content, flags=re.DOTALL).strip()
            # If there is meaningful content after the thinking block, use it
            if remaining_content and not remaining_content.startswith('<'):
                content = remaining_content
            else:
                # Still in the thinking phase or no response yet, skip this turn
                continue
        # Check for an audio component
        is_audio = isinstance(content, dict) and content.get("component") == "audio"
        if is_audio:
            audio_path = content["value"]["path"]
            if audio_path and os.path.exists(audio_path):
                try:
                    item_audio_data_list = process_audio(audio_path)
                    new_content = []
                    for audio_data in item_audio_data_list:
                        new_content.append({
                            "type": "input_audio",
                            "input_audio": {
                                "data": "data:audio/wav;base64," + audio_data,
                                "format": "wav"
                            }
                        })
                    messages.append({"role": role, "content": new_content})
                except Exception as e:
                    print(f"[ERROR] Failed to process history audio: {e}")
        elif isinstance(content, str):
            messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            # Process list items and ensure text comes before audio
            text_items = []
            audio_items = []
            other_items = []
            for c in content:
                # Check for an audio component inside the list
                is_c_audio = isinstance(c, dict) and c.get('component') == "audio"
                if is_c_audio:
                    audio_path = c["value"]["path"]
                    if audio_path and os.path.exists(audio_path):
                        try:
                            item_audio_data_list = process_audio(audio_path)
                            for audio_data in item_audio_data_list:
                                audio_items.append({
                                    "type": "input_audio",
                                    "input_audio": {
                                        "data": "data:audio/wav;base64," + audio_data,
                                        "format": "wav"
                                    }
                                })
                        except Exception as e:
                            print(f"[ERROR] Failed to process history audio in list: {e}")
                elif isinstance(c, str):
                    text_items.append({"type": "text", "text": c})
                elif isinstance(c, dict):
                    # Distinguish between text and audio parts
                    if c.get("type") == "text":
                        text_items.append(c)
                    elif c.get("type") == "input_audio":
                        audio_items.append(c)
                    else:
                        other_items.append(c)
            # Combine: text first, then audio, then everything else
            safe_content = text_items + audio_items + other_items
            if safe_content:
                messages.append({"role": role, "content": safe_content})
    # Append the current user message (text first, audio second)
    if user_text and audio_data_list:
        content = []
        # Add the text part first
        content.append({
            "type": "text",
            "text": user_text
        })
        # Then add the audio parts
        for audio_data in audio_data_list:
            content.append({
                "type": "input_audio",
                "input_audio": {
                    "data": "data:audio/wav;base64," + audio_data,
                    "format": "wav"
                }
            })
        messages.append({
            "role": "user",
            "content": content
        })
    elif user_text:
        messages.append({"role": "user", "content": user_text})
    elif audio_data_list:
        content = []
        for audio_data in audio_data_list:
            content.append({
                "type": "input_audio",
                "input_audio": {
                    "data": "data:audio/wav;base64," + audio_data,
                    "format": "wav"
                }
            })
        messages.append({
            "role": "user",
            "content": content
        })
    return messages


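# chat() is a generator: every yield hands Gradio an updated (history, used_audio_size,
# audio_size_info) triple so the chatbot, the cumulative audio quota and the info line
# all refresh while the model is still streaming tokens.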
def chat(system_prompt, user_text, audio_file, history, used_audio_size, max_tokens, temperature, top_p, show_thinking=True, model_name=None):
    """Chat handler: streams the assistant response into the Gradio chatbot."""
    # If no model is specified, fall back to the global configuration
    if model_name is None:
        model_name = MODEL_NAME
    # Initialise the running total of used audio size
    if used_audio_size is None:
        used_audio_size = 0
    if not user_text and not audio_file:
        yield history or [], used_audio_size, get_audio_size_info(used_audio_size, None)
        return
    # Check the audio size limit
    current_audio_size = 0
    if audio_file:
        current_audio_size = get_wav_size(audio_file)
    total_size = used_audio_size + current_audio_size
    if total_size > MAX_AUDIO_SIZE_BYTES:
        history = history or []
        remaining_mb = (MAX_AUDIO_SIZE_BYTES - used_audio_size) / (1024 * 1024)
        current_mb = current_audio_size / (1024 * 1024)
        error_msg = f"❌ Audio size limit exceeded! Current audio is {current_mb:.2f}MB, but only {max(0, remaining_mb):.2f}MB remaining (max {MAX_AUDIO_SIZE_MB}MB)"
        history.append({"role": "assistant", "content": error_msg})
        yield history, used_audio_size, get_audio_size_info(used_audio_size, None)
        return
    # Ensure history is a list of correctly shaped entries
    history = history or []
    clean_history = []
    for item in history:
        if isinstance(item, dict) and 'role' in item and 'content' in item:
            clean_history.append(item)
        elif hasattr(item, "role") and hasattr(item, "content"):
            # Keep ChatMessage objects as-is
            clean_history.append(item)
    history = clean_history
    # Process the uploaded audio
    audio_data_list = []
    if audio_file:
        audio_data_list = process_audio(audio_file)
    messages = format_messages(system_prompt, history, user_text, audio_data_list)
    if not messages:
        yield history or [], used_audio_size, get_audio_size_info(used_audio_size, None)
        return
    # Debug: print the outgoing messages (base64 audio payloads are summarised)
    debug_messages = []
    for msg in messages:
        if isinstance(msg, dict) and isinstance(msg.get("content"), list):
            new_content = []
            for item in msg["content"]:
                if isinstance(item, dict) and item.get("type") == "input_audio":
                    item_copy = item.copy()
                    if "input_audio" in item_copy:
                        audio_info = item_copy["input_audio"].copy()
                        if "data" in audio_info:
                            data_len = len(audio_info['data'])
                            if data_len >= 1024 * 1024:
                                human_size = f"{data_len / (1024 * 1024):.2f} MB"
                            elif data_len >= 1024:
                                human_size = f"{data_len / 1024:.2f} KB"
                            else:
                                human_size = f"{data_len} B"
                            audio_info["data"] = f"[BASE64_AUDIO_DATA: {human_size} ({data_len} bytes)]"
                        item_copy["input_audio"] = audio_info
                    new_content.append(item_copy)
                else:
                    new_content.append(item)
            msg_copy = msg.copy()
            msg_copy["content"] = new_content
            debug_messages.append(msg_copy)
        else:
            debug_messages.append(msg)
    print(f"[DEBUG] Messages to API: {json.dumps(debug_messages, ensure_ascii=False, indent=2)}")
    # Update history with the user message immediately (text first, then audio)
    if user_text and audio_file:
        # 1. Add the text message first
        history.append({"role": "user", "content": user_text})
        # 2. Add the audio message second
        history.append({"role": "user", "content": gr.Audio(audio_file)})
    elif user_text:
        # Text only
        history.append({"role": "user", "content": user_text})
    elif audio_file:
        # Audio only
        history.append({"role": "user", "content": gr.Audio(audio_file)})
    # Update the cumulative used audio size
    new_used_audio_size = used_audio_size + current_audio_size
    # Add a thinking placeholder
    if show_thinking:
        history.append({
            "role": "assistant",
            "content": (
                '<div class="thinking-block">\n'
                '<div class="thinking-header">💭 Thinking...</div>\n'
                '<div class="thinking-content">Processing your request...</div>\n'
                '</div>'
            )
        })
        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
    else:
        history.append({
            "role": "assistant",
            "content": "⏳ Generating response..."
        })
        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
    try:
        # Call the internal API directly; trust_env=False ignores any proxy environment variables
        start_time = time.time()
        print(f"[API] Starting request to {API_BASE_URL}/chat/completions ...")
        with httpx.Client(base_url=API_BASE_URL, timeout=120, trust_env=False) as client:
            response = client.post("/chat/completions", json={
                "model": model_name,
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stream": True,
                "repetition_penalty": 1.0,
                "stop_token_ids": [151665]
            }, headers={
                "Authorization": f"Bearer {SECRET}",
            })
            if response.status_code != 200:
                elapsed_time = time.time() - start_time
                print(f"[API] ❌ FAILED - Status: {response.status_code}, Time: {elapsed_time:.2f}s")
                error_msg = f"❌ API Error {response.status_code}"
                if response.status_code == 404:
                    error_msg += " - vLLM service not ready"
                elif response.status_code == 400:
                    error_msg += f" - Bad request ({response.text})"
                elif response.status_code == 500:
                    error_msg += f" - Model error ({response.text})"
                # Update the last message with the error
                history[-1]["content"] = error_msg
                yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
                return
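            # The endpoint streams Server-Sent Events: each payload line looks like
            # "data: {json chunk}" and the stream ends with "data: [DONE]". The model wraps
            # its reasoning in <think>...</think>, so the accumulated buffer is split on
            # </think> to separate the thinking trace from the final answer.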
            # Process the streaming response
            buffer = ""
            is_thinking = True
            for line in response.iter_lines():
                if not line:
                    continue
                # Ensure the line is a string
                if isinstance(line, bytes):
                    line = line.decode('utf-8')
                else:
                    line = str(line)
                if line.startswith('data: '):
                    data_str = line[6:]
                    if data_str.strip() == '[DONE]':
                        break
                    try:
                        data = json.loads(data_str)
                        if 'choices' in data and len(data['choices']) > 0:
                            delta = data['choices'][0].get('delta', {})
                            if 'content' in delta:
                                content = delta['content']
                                buffer += content
                                if is_thinking:
                                    if "</think>" in buffer:
                                        is_thinking = False
                                        parts = buffer.split("</think>", 1)
                                        think_content = parts[0]
                                        response_content = parts[1]
                                        if think_content.startswith("<think>"):
                                            think_content = think_content[len("<think>"):].strip()
                                        if show_thinking:
                                            # Format thinking in a custom styled block (escape HTML for safety)
                                            escaped_think = escape_html(think_content)
                                            formatted_content = (
                                                f'<div class="thinking-block">\n'
                                                f'<div class="thinking-header">💭 Thinking Process</div>\n'
                                                f'<div class="thinking-content">{escaped_think}</div>\n'
                                                f'</div>\n\n'
                                                f'{response_content}'
                                            )
                                            history[-1]["content"] = formatted_content
                                        else:
                                            # Not showing thinking: replace with the response directly
                                            history[-1]["content"] = response_content
                                    else:
                                        # Still thinking: update the thinking block (only if showing)
                                        if show_thinking:
                                            current_think = buffer
                                            if current_think.startswith("<think>"):
                                                current_think = current_think[len("<think>"):].strip()
                                            escaped_think = escape_html(current_think)
                                            formatted_content = (
                                                f'<div class="thinking-block">\n'
                                                f'<div class="thinking-header">💭 Thinking...</div>\n'
                                                f'<div class="thinking-content">{escaped_think}</div>\n'
                                                f'</div>'
                                            )
                                            history[-1]["content"] = formatted_content
                                else:
                                    # Already past </think>: update the combined message
                                    parts = buffer.split("</think>", 1)
                                    think_content = parts[0]
                                    response_content = parts[1]
                                    if think_content.startswith("<think>"):
                                        think_content = think_content[len("<think>"):].strip()
                                    if show_thinking:
                                        # Update with formatted thinking + response
                                        escaped_think = escape_html(think_content)
                                        formatted_content = (
                                            f'<div class="thinking-block">\n'
                                            f'<div class="thinking-header">💭 Thinking Process</div>\n'
                                            f'<div class="thinking-content">{escaped_think}</div>\n'
                                            f'</div>\n\n'
                                            f'{response_content}'
                                        )
                                        history[-1]["content"] = formatted_content
                                    else:
                                        # Only show the response
                                        history[-1]["content"] = response_content
                                yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
                    except json.JSONDecodeError:
                        continue
        # Request completed successfully
        elapsed_time = time.time() - start_time
        print(f"[API] ✅ SUCCESS - Time: {elapsed_time:.2f}s")
    except httpx.ConnectError:
        elapsed_time = time.time() - start_time
        print(f"[API] ❌ FAILED - Connection error, Time: {elapsed_time:.2f}s")
        history[-1]["content"] = "❌ Cannot connect to vLLM API"
        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)
    except Exception as e:
        elapsed_time = time.time() - start_time
        print(f"[API] ❌ FAILED - Error: {str(e)}, Time: {elapsed_time:.2f}s")
        history[-1]["content"] = f"❌ Error: {str(e)}"
        yield history, new_used_audio_size, get_audio_size_info(new_used_audio_size, None)


# Custom CSS for better UI
custom_css = """
/* Global styles */
.gradio-container {
    max-width: 100% !important;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
}

/* Header styles */
.app-header {
    text-align: center;
    padding: 2.5rem 1.5rem;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    position: relative;
    overflow: hidden;
    border-radius: 16px;
    margin-bottom: 1.5rem;
    box-shadow: 0 8px 24px rgba(102, 126, 234, 0.35);
}

/* Header background decoration */
.app-header::before {
    content: '';
    position: absolute;
    top: -50%;
    right: -50%;
    width: 200%;
    height: 200%;
    background: radial-gradient(circle, rgba(255, 255, 255, 0.1) 0%, transparent 70%);
    animation: rotate 20s linear infinite;
}

@keyframes rotate {
    from { transform: rotate(0deg); }
    to { transform: rotate(360deg); }
}

.app-header h1 {
    margin: 0;
    font-size: 2.8rem;
    font-weight: 700;
    color: white !important;
    text-shadow: 0 3px 6px rgba(0, 0, 0, 0.25);
    letter-spacing: 1px;
    position: relative;
    z-index: 1;
}

.app-header p {
    color: rgba(255, 255, 255, 0.95) !important;
    text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
    position: relative;
    z-index: 1;
    line-height: 1.5;
}

/* Chat box styles */
.chatbot-container {
    border-radius: 12px;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
    overflow: hidden;
}

/* Thinking-process styles, modeled on the Claude/ChatGPT look */
.thinking-block {
    background: linear-gradient(135deg, #f5f7fa 0%, #eef2f7 100%);
    border-left: 4px solid #667eea;
    padding: 16px 20px;
    margin: 12px 0;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}

.thinking-header {
    display: flex;
    align-items: center;
    font-weight: 600;
    color: #667eea;
    margin-bottom: 10px;
    font-size: 0.95rem;
}

.thinking-content {
    background: #ffffff;
    padding: 12px 16px;
    border-radius: 6px;
    font-family: 'SF Mono', Monaco, 'Cascadia Code', 'Roboto Mono', Consolas, 'Courier New', monospace;
    font-size: 0.9rem;
    line-height: 1.6;
    color: #374151;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: 1px solid #e5e7eb;
}

/* Response divider */
.response-divider {
    border: none;
    height: 2px;
    background: linear-gradient(to right, transparent, #e5e7eb, transparent);
    margin: 20px 0;
}

/* Button styles */
.primary-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    transition: all 0.3s ease !important;
}

.primary-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4) !important;
}

/* Left panel styles */
.left-panel {
    background: #f9fafb;
    border-radius: 12px;
    padding: 1rem;
    height: 100%;
}

/* Input box styles */
.input-box textarea {
    border-radius: 8px !important;
    border: 2px solid #e5e7eb !important;
    transition: border-color 0.3s ease !important;
}

.input-box textarea:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important;
}

/* Input section headings */
h3 {
    color: #374151;
    font-size: 1.1rem;
    margin: 1rem 0 0.5rem 0;
}

/* Chat message style tweaks */
.message-wrap {
    padding: 1rem !important;
}

.message {
    padding: 1rem !important;
    border-radius: 12px !important;
    line-height: 1.6 !important;
}

/* User messages */
.message.user {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    color: white !important;
}

/* Assistant messages */
.message.bot {
    background: #f9fafb !important;
    border: 1px solid #e5e7eb !important;
}

/* Left column overall styles */
.left-column {
    background: linear-gradient(to bottom, #ffffff 0%, #f9fafb 100%);
    border-radius: 12px;
    padding: 1rem;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}

/* Button row styles */
.button-row {
    margin-top: 1rem;
    gap: 0.5rem;
}

/* Dark mode support */
.dark .message.bot {
    background: #1f2937 !important;
    border: 1px solid #374151 !important;
    color: #e5e7eb !important;
}

.dark .thinking-block {
    background: linear-gradient(135deg, #1f2937 0%, #111827 100%);
    border-left: 4px solid #4f46e5;
}

.dark .thinking-content {
    background: #111827;
    color: #e5e7eb;
    border: 1px solid #374151;
}

.dark .thinking-header {
    color: #818cf8;
}

.dark .left-panel {
    background: #111827;
}

.dark .left-column {
    background: linear-gradient(to bottom, #1f2937 0%, #111827 100%);
}

.dark .input-box textarea {
    background-color: #1f2937;
    border-color: #374151 !important;
    color: #e5e7eb;
}

.dark h3 {
    color: #e5e7eb;
}

/* Scrollbar styling */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}

::-webkit-scrollbar-track {
    background: #f1f1f1;
    border-radius: 4px;
}

::-webkit-scrollbar-thumb {
    background: #888;
    border-radius: 4px;
}

::-webkit-scrollbar-thumb:hover {
    background: #555;
}
"""

# Gradio Interface
with gr.Blocks(title="Step Audio R1", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div class="app-header">
        <h1 style="color: white;">🔊 Step-Audio-R1</h1>
        <p style="color: white; margin: 0.8rem 0 0 0; opacity: 0.95; font-size: 1.15rem; font-weight: 500;">
            Advanced Audio-Language Model with Reasoning
        </p>
        <p style="color: white; margin: 0.5rem 0 0 0; opacity: 0.85; font-size: 0.95rem;">
            Comprehensive audio understanding: Speech, Sound, Music & Lyrics
        </p>
    </div>
    """)

    with gr.Row():
        # Left Panel - Input Area
        with gr.Column(scale=1, min_width=350):
            # Configuration
            with gr.Accordion("⚙️ Configuration", open=False):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    lines=2,
                    value="You are a voice assistant with extensive experience in audio processing.",
                    placeholder="Enter system prompt...",
                    elem_classes=["input-box"]
                )
                max_tokens = gr.Slider(
                    1, 56000,
                    value=16384,
                    label="Max Tokens",
                    info="Maximum tokens to generate"
                )
                temperature = gr.Slider(
                    0.0, 2.0,
                    value=0.7,
                    label="Temperature",
                    info="Higher = more random"
                )
                top_p = gr.Slider(
                    0.0, 1.0,
                    value=0.9,
                    label="Top P",
                    info="Nucleus sampling"
                )
                show_thinking = gr.Checkbox(
                    label="💭 Show Thinking Process",
                    value=True,
                    info="Display reasoning steps"
                )

            # Input Area
            gr.Markdown("### 📝 Your Input")
            user_text = gr.Textbox(
                label="Text Message",
                lines=4,
                placeholder="Type your message here...",
                elem_classes=["input-box"],
                show_label=False
            )
            audio_file = gr.Audio(
                label="🎤 Audio Input",
                type="filepath",
                sources=["microphone", "upload"],
                show_label=True
            )
            # Audio size limit info
            audio_size_info = gr.Markdown(
                value=f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available",
                elem_classes=["audio-size-info"]
            )

            # Buttons
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", scale=1, size="lg")
                submit_btn = gr.Button(
                    "🚀 Send",
                    variant="primary",
                    scale=2,
                    size="lg",
                    elem_classes=["primary-btn"]
                )

            # Usage Guide at the bottom
            with gr.Accordion("📖 Quick Guide", open=False):
                gr.Markdown("""
                **Usage:**
                - Type text, upload audio, or both
                - Audio > 25s auto-splits
                - Toggle thinking process display

                **Tips:**
                - Thinking shown in blue gradient block
                - History auto-cleaned for API
                - Adjust params in Configuration
                """)

        # Right Panel - Conversation Area
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="💬 Conversation",
                height=700,
                type="messages",
                elem_classes=["chatbot-container"],
                show_label=True,
                avatar_images=(None, None),
                bubble_full_width=False
            )

    # State to track used audio size (in bytes)
    used_audio_size = gr.State(value=0)

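    # Event wiring: chat() streams into the chatbot while used_audio_size (a gr.State)
    # accumulates across turns; Clear resets the conversation, the inputs, and the quota.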
    submit_btn.click(
        fn=chat,
        inputs=[system_prompt, user_text, audio_file, chatbot, used_audio_size, max_tokens, temperature, top_p, show_thinking],
        outputs=[chatbot, used_audio_size, audio_size_info]
    )
    clear_btn.click(
        fn=lambda: ([], 0, "", None, f"📊 Audio limit: {MAX_AUDIO_SIZE_MB}MB total available"),
        outputs=[chatbot, used_audio_size, user_text, audio_file, audio_size_info]
    )
    # Update the audio size info whenever the audio file changes
    audio_file.change(
        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
        inputs=[audio_file, used_audio_size],
        outputs=[audio_size_info]
    )
    # Also listen to upload and stop_recording events
    audio_file.upload(
        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
        inputs=[audio_file, used_audio_size],
        outputs=[audio_size_info]
    )
    audio_file.stop_recording(
        fn=lambda audio, used_size: get_audio_size_info(used_size, audio),
        inputs=[audio_file, used_audio_size],
        outputs=[audio_size_info]
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()
    # Update the global model name
    if args.model:
        MODEL_NAME = args.model
    print(f"Starting Gradio at: http://{args.host}:{args.port}")
    print(f"API base URL: {API_BASE_URL}")
    print(f"Model: {MODEL_NAME}")
    demo.launch(server_name=args.host, server_port=args.port, share=False)