#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gradio multimodal chat interface: runs inference by calling vllm.LLM directly
inside app.py.
"""
import base64
import os
import sys
import threading
import time
from typing import Optional, Tuple

import gradio as gr
# Inspect the command line before importing vllm, so the interface can be
# previewed even when vllm is not installed.
if "--no-vllm" in sys.argv:
    os.environ["ENABLE_VLLM"] = "false"

# Decide whether vLLM mode is enabled.
ENABLE_VLLM = os.getenv("ENABLE_VLLM", "true").lower() in ("true", "1", "yes")
if ENABLE_VLLM:
    try:
        from vllm import LLM, SamplingParams
    except ImportError:
        print("[WARNING] Failed to import vllm; falling back to UI preview mode")
        print("[INFO] To use vLLM, install it first: pip install vllm")
        ENABLE_VLLM = False
        LLM = None
        SamplingParams = None
else:
    LLM = None
    SamplingParams = None
    print("[INFO] Running in UI preview mode; vLLM will not be loaded")
# Default configuration; can be overridden via environment variables or the CLI.
DEFAULT_MODEL_ID = os.getenv("MODEL_NAME", "stepfun-ai/Step-Audio-2-mini-Think")
DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", DEFAULT_MODEL_ID)
DEFAULT_TP = int(os.getenv("TENSOR_PARALLEL_SIZE", "4"))
DEFAULT_MAX_MODEL_LEN = int(os.getenv("MAX_MODEL_LEN", "8192"))
DEFAULT_GPU_UTIL = float(os.getenv("GPU_MEMORY_UTILIZATION", "0.9"))
DEFAULT_TOKENIZER_MODE = os.getenv("TOKENIZER_MODE", "step_audio_2")
DEFAULT_SERVED_NAME = os.getenv("SERVED_MODEL_NAME", "step-audio-2-mini-think")

# Lazily initialized LLM singleton; the annotation is quoted so it remains
# valid when vllm is unavailable and LLM is None.
_llm: Optional["LLM"] = None
_llm_lock = threading.Lock()

LLM_ARGS = {
    "model": DEFAULT_MODEL_PATH,
    "trust_remote_code": True,
    "tensor_parallel_size": DEFAULT_TP,
    "tokenizer_mode": DEFAULT_TOKENIZER_MODE,
    "max_model_len": DEFAULT_MAX_MODEL_LEN,
    "served_model_name": DEFAULT_SERVED_NAME,
    "gpu_memory_utilization": DEFAULT_GPU_UTIL,
}
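
# Illustrative overrides (environment variable names as defined above; the
# values here are examples, not recommendations):
#
#   ENABLE_VLLM=false python app.py
#   TENSOR_PARALLEL_SIZE=1 MAX_MODEL_LEN=4096 python app.py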

def encode_audio_to_base64(audio_path: Optional[str]) -> Optional[dict]:
    """Encode an audio file as base64."""
    if audio_path is None:
        return None
    try:
        with open(audio_path, "rb") as audio_file:
            audio_data = audio_file.read()
        audio_base64 = base64.b64encode(audio_data).decode("utf-8")
        # Infer the audio format from the file extension.
        ext = os.path.splitext(audio_path)[1].lower().lstrip(".")
        if not ext:
            ext = "wav"  # default format
        return {
            "data": audio_base64,
            "format": ext,
        }
    except Exception as e:
        print(f"Error encoding audio: {e}")
        return None
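
# For reference, a successful call returns a payload shaped like
#   {"data": "<base64-encoded bytes>", "format": "wav"}
# which is exactly the "input_audio" part that format_messages() attaches below.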

def format_messages(
    system_prompt: str,
    chat_history: list,
    user_text: str,
    audio_file: Optional[str],
) -> list:
    """Format the conversation as OpenAI API style messages."""
    messages = []
    # System prompt.
    if system_prompt and system_prompt.strip():
        messages.append({
            "role": "system",
            "content": system_prompt.strip(),
        })
    # Prior turns.
    for human, assistant in chat_history:
        if human:
            messages.append({"role": "user", "content": human})
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    # Current user input.
    content_parts = []
    # Text part.
    if user_text and user_text.strip():
        content_parts.append({
            "type": "text",
            "text": user_text.strip(),
        })
    # Audio part.
    if audio_file:
        audio_data = encode_audio_to_base64(audio_file)
        if audio_data:
            content_parts.append({
                "type": "input_audio",
                "input_audio": audio_data,
            })
    if content_parts:
        # A lone text part is sent as a plain string.
        if len(content_parts) == 1 and content_parts[0]["type"] == "text":
            messages.append({
                "role": "user",
                "content": content_parts[0]["text"],
            })
        else:
            messages.append({
                "role": "user",
                "content": content_parts,
            })
    return messages
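
# Illustrative shape of the result for a combined text + audio turn (the
# prompt text is hypothetical and values are abbreviated):
#
# [
#     {"role": "system", "content": "You are an expert in audio analysis..."},
#     {"role": "user", "content": [
#         {"type": "text", "text": "What is said in this clip?"},
#         {"type": "input_audio", "input_audio": {"data": "<base64>", "format": "wav"}},
#     ]},
# ]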

def chat_predict(
    system_prompt: str,
    user_text: str,
    audio_file: Optional[str],
    chat_history: list,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Tuple[list, str]:
    """Run inference through the local vLLM LLM."""
    if not user_text and not audio_file:
        return chat_history, "⚠ Please provide text or audio input"
    # Preview mode returns a mock response.
    if not ENABLE_VLLM:
        user_display = user_text if user_text else "[audio input]"
        mock_response = f"This is a mock reply. You said: {user_text[:50] if user_text else 'audio'}"
        chat_history.append((user_display, mock_response))
        return chat_history, ""
    messages = format_messages(system_prompt, chat_history, user_text, audio_file)
    if not messages:
        return chat_history, "⚠ No valid input"
    try:
        llm = _get_llm()
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        start_time = time.time()
        outputs = llm.chat(messages, sampling_params=sampling_params, use_tqdm=False)
        latency = time.time() - start_time
        print(f"[LLM] Inference latency: {latency:.2f}s")
        if not outputs or not outputs[0].outputs:
            return chat_history, "⚠ The model returned no output"
        assistant_message = outputs[0].outputs[0].text
        user_display = user_text if user_text else "[audio input]"
        chat_history.append((user_display, assistant_message))
        return chat_history, ""
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Surface the error instead of silently returning an empty status.
        return chat_history, f"⚠ Inference failed: {e}"

def _get_llm() -> "LLM":
    """Initialize the LLM as a singleton (double-checked locking)."""
    if not ENABLE_VLLM:
        raise RuntimeError("vLLM is not enabled; cannot load the model")
    global _llm
    if _llm is not None:
        return _llm
    with _llm_lock:
        if _llm is not None:
            return _llm
        print(f"[LLM] Initializing with args: {LLM_ARGS}")
        _llm = LLM(**LLM_ARGS)
    return _llm


def _set_llm_args(**kwargs) -> None:
    """Replace the LLM initialization arguments."""
    global LLM_ARGS, _llm
    LLM_ARGS = kwargs
    _llm = None  # force a reload with the new configuration

# Build the Gradio interface.
with gr.Blocks(title="Step Audio 2 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Step Audio R1 Demo
        """
    )
    with gr.Row():
        # Left column: parameter configuration.
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            system_prompt = gr.Textbox(
                label="System Prompt",
                placeholder="Enter a system prompt...",
                lines=4,
                value="You are an expert in audio analysis, please analyze the audio content and answer the questions accurately",
            )
            with gr.Row():
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=1,
                    maximum=16384,
                    value=8192,
                    step=1,
                )
            with gr.Row():
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                )
        # Right column: conversation and input.
        with gr.Column(scale=1):
            gr.Markdown("### Conversation")
            chatbot = gr.Chatbot(
                label="Chat history",
                height=400,
                show_copy_button=True,
            )
            user_text = gr.Textbox(
                label="Text input",
                placeholder="Type your message...",
                lines=2,
            )
            audio_file = gr.Audio(
                label="Audio input",
                type="filepath",
                sources=["microphone", "upload"],
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")
            # Visible so the warnings returned by chat_predict are shown.
            status_text = gr.Textbox(label="Status", interactive=False, visible=True)

    # Event bindings.
    submit_btn.click(
        fn=chat_predict,
        inputs=[
            system_prompt,
            user_text,
            audio_file,
            chatbot,
            max_tokens,
            temperature,
            top_p,
        ],
        outputs=[chatbot, status_text],
    )
    clear_btn.click(
        fn=lambda: ([], "", None),
        outputs=[chatbot, user_text, audio_file],
    )

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Step Audio 2 Gradio Chat Interface")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Server host address",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Server port",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help="Model name or local path",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        type=int,
        default=DEFAULT_TP,
        help="Tensor parallel size",
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=DEFAULT_MAX_MODEL_LEN,
        help="Maximum context length",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=DEFAULT_GPU_UTIL,
        help="GPU memory utilization",
    )
    parser.add_argument(
        "--tokenizer-mode",
        type=str,
        default=DEFAULT_TOKENIZER_MODE,
        help="Tokenizer mode",
    )
    parser.add_argument(
        "--served-model-name",
        type=str,
        default=DEFAULT_SERVED_NAME,
        help="Model name exposed to clients",
    )
    parser.add_argument(
        "--no-vllm",
        action="store_true",
        help="Disable vLLM and launch the UI preview only",
    )
    args = parser.parse_args()

    # --no-vllm is handled at the top of the file; this is just a notice.
    if args.no_vllm and not ENABLE_VLLM:
        print("[INFO] vLLM disabled; running in UI preview mode")

    _set_llm_args(
        model=args.model,
        trust_remote_code=True,
        tensor_parallel_size=args.tensor_parallel_size,
        tokenizer_mode=args.tokenizer_mode,
        max_model_len=args.max_model_len,
        served_model_name=args.served_model_name,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )

    print("==========================================")
    print("Step Audio 2 Gradio Chat")
    if ENABLE_VLLM:
        print("Mode: vLLM inference")
        print(f"Model: {args.model}")
        print(f"Tensor Parallel Size: {args.tensor_parallel_size}")
        print(f"Max Model Len: {args.max_model_len}")
        print(f"Tokenizer Mode: {args.tokenizer_mode}")
        print(f"Served Model Name: {args.served_model_name}")
    else:
        print("Mode: UI preview (no vLLM)")
    print(f"Gradio URL: http://{args.host}:{args.port}")
    print("==========================================")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=False,
    )
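
# Example invocations (flags as defined above; paths and values are
# illustrative only):
#
#   python app.py --no-vllm                      # UI preview without loading vLLM
#   python app.py --model stepfun-ai/Step-Audio-2-mini-Think --tensor-parallel-size 4
#   python app.py --host 127.0.0.1 --port 8080 --max-model-len 8192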