# File size: 18,051 Bytes
# f833e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import gradio as gr
import requests
import json
import os
import warnings
from huggingface_hub import InferenceClient

# Silence DeprecationWarning noise (originally aimed at asyncio warnings) in this process.
warnings.filterwarnings('ignore', category=DeprecationWarning)
# NOTE: PYTHONWARNINGS is only read at interpreter start-up, so this assignment
# affects child processes spawned later, not the running interpreter.
os.environ['PYTHONWARNINGS'] = 'ignore'

# The app does not need a GPU: hide CUDA devices unless the environment
# already decided which ones to expose.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '')

# ========== MCP tool schemas (OpenAI-style function definitions) ==========
# Each row: (tool name, description, {argument: JSON type}, required arguments).
MCP_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": tool_description,
            "parameters": {
                "type": "object",
                "properties": {arg: {"type": arg_type} for arg, arg_type in arg_types.items()},
                "required": list(required_args),
            },
        },
    }
    for tool_name, tool_description, arg_types, required_args in (
        ("advanced_search_company", "Search US companies", {"company_input": "string"}, ("company_input",)),
        ("get_latest_financial_data", "Get latest financial data", {"cik": "string"}, ("cik",)),
        ("extract_financial_metrics", "Get multi-year trends", {"cik": "string", "years": "integer"}, ("cik", "years")),
        ("get_quote", "Get stock quote", {"symbol": "string"}, ("symbol",)),
        ("get_market_news", "Get market news", {"category": "string"}, ("category",)),
        ("get_company_news", "Get company news", {"symbol": "string", "from_date": "string", "to_date": "string"}, ("symbol",)),
    )
]

# ========== MCP service configuration ==========
# "fastmcp" services speak MCP JSON-RPC; "gradio" services use Gradio's call/SSE API.
MCP_SERVICES = {
    "financial": {
        "url": "http://localhost:7861/mcp",
        "type": "fastmcp",
    },
    "market": {
        "url": "https://jc321-marketandstockmcp.hf.space",
        "type": "gradio",
    },
}

# Map every tool name to the (shared) service config dict that serves it.
TOOL_ROUTING = {
    tool: MCP_SERVICES[service_key]
    for service_key, tools in (
        ("financial", ("advanced_search_company", "get_latest_financial_data", "extract_financial_metrics")),
        ("market", ("get_quote", "get_market_news", "get_company_news")),
    )
    for tool in tools
}

# ========== LLM client initialization ==========
# Prefer an explicit token from the environment; otherwise let InferenceClient
# fall back to its own default credential lookup.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if hf_token:
    client = InferenceClient(api_key=hf_token)
else:
    client = InferenceClient()

# Start-up banner (the model name is informational; no request is made here).
print("✅ LLM initialized: Qwen/Qwen3-32B:groq")
print("📊 MCP Services: {} services, {} tools".format(len(MCP_SERVICES), len(MCP_TOOLS)))

# ========== Token / size limits ==========
# The HF Inference API context limit is roughly 8000-16000 tokens in practice,
# so these limits are deliberately conservative.
MAX_TOTAL_TOKENS = 6_000       # overall context budget (NOTE(review): not referenced by the code visible here)
MAX_TOOL_RESULT_CHARS = 1_500  # max characters of a tool result sent back to the LLM
MAX_HISTORY_CHARS = 500        # max characters kept from each replayed assistant message
MAX_HISTORY_TURNS = 2          # number of most recent chat turns replayed to the LLM
MAX_TOOL_ITERATIONS = 6        # max LLM rounds allowed to request tool calls
MAX_OUTPUT_TOKENS = 2_000      # max_tokens for each LLM completion

def estimate_tokens(text):
    """Return a rough token count for *text*, assuming ~2 characters per token.

    Non-string inputs are measured through ``str()``; the result is floored.
    """
    char_count = len(str(text))
    return char_count // 2

def truncate_text(text, max_chars, suffix="...[truncated]"):
    """Return ``str(text)``, cut to *max_chars* characters plus *suffix* when longer.

    Note that a truncated result is ``max_chars + len(suffix)`` characters long.
    """
    as_str = str(text)
    if len(as_str) > max_chars:
        return as_str[:max_chars] + suffix
    return as_str

def get_system_prompt():
    """Build the compact system prompt, stamped with today's local date (YYYY-MM-DD)."""
    import datetime as _dt

    today = _dt.datetime.now().strftime("%Y-%m-%d")
    return (
        "Financial analyst. Today: " + today
        + ". Use tools for company data, stock prices, news. Be concise."
    )

# ============================================================
# MCP service invocation
# Supports FastMCP (JSON-RPC) and Gradio (SSE) protocols
# ============================================================

def call_mcp_tool(tool_name, arguments):
    """Route *tool_name* to its configured MCP backend and return the result.

    Args:
        tool_name: Name of the tool as exposed to the LLM (a key of TOOL_ROUTING).
        arguments: Arguments dict forwarded to the backend.

    Returns:
        The backend's result, or ``{"error": ...}`` for an unknown tool, an unknown
        service type, or any exception raised while calling the backend.
    """
    route = TOOL_ROUTING.get(tool_name)
    if not route:
        return {"error": f"Unknown tool: {tool_name}"}

    handlers = {"fastmcp": _call_fastmcp, "gradio": _call_gradio_api}
    try:
        handler = handlers.get(route["type"])
        if handler is None:
            return {"error": "Unknown service type"}
        return handler(route["url"], tool_name, arguments)
    except Exception as exc:
        return {"error": str(exc)}


def _call_fastmcp(service_url, tool_name, arguments):
    """Call a tool on a FastMCP server with a single MCP JSON-RPC ``tools/call`` request.

    Args:
        service_url: JSON-RPC endpoint of the MCP server.
        tool_name: Name of the MCP tool to invoke.
        arguments: Arguments dict for the tool.

    Returns:
        - ``{"error": "HTTP <code>"}`` for a non-200 response;
        - the whole response body when it has no ``result`` member (e.g. a JSON-RPC
          ``error`` object, which callers detect through its ``"error"`` key);
        - ``{"error": <text>}`` when the MCP result is flagged ``isError``;
        - the JSON-decoded text of the first content item, or ``{"text": ...}`` when
          that text is not JSON;
        - otherwise the raw ``result`` value.

    Raises:
        requests.RequestException / ValueError: network failures or a non-JSON body;
        ``call_mcp_tool`` converts these into ``{"error": ...}``.
    """
    response = requests.post(
        service_url,
        json={"jsonrpc": "2.0", "method": "tools/call", "params": {"name": tool_name, "arguments": arguments}, "id": 1},
        headers={"Content-Type": "application/json"},
        timeout=30,
    )

    if response.status_code != 200:
        return {"error": f"HTTP {response.status_code}"}

    data = response.json()
    if not (isinstance(data, dict) and "result" in data):
        return data
    result = data["result"]

    # Unwrap the MCP envelope: result -> content[0].text -> (usually) JSON.
    first_item = None
    if isinstance(result, dict):
        content = result.get("content")
        if (
            isinstance(content, list)
            and content
            and isinstance(content[0], dict)
            and "text" in content[0]
        ):
            first_item = content[0]

        # MCP reports tool-execution failures inside a successful JSON-RPC response
        # with isError=true; surface them as errors instead of ordinary data.
        if result.get("isError"):
            message = first_item["text"] if first_item is not None else None
            return {"error": message or "MCP tool reported an error"}

    if first_item is None:
        return result

    text = first_item["text"]
    try:
        return json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return {"text": text}


def _call_gradio_api(service_url, tool_name, arguments):
    """Call a tool exposed by a Gradio app through Gradio's two-step HTTP API.

    Step 1 POSTs the positional inputs to ``{service_url}/call/{fn}`` and reads the
    returned ``event_id``. Step 2 GETs ``{service_url}/call/{fn}/{event_id}`` and parses
    the Server-Sent-Events stream until the first ``data:`` payload that is a non-empty
    list, or until an ``error`` event arrives.

    Args:
        service_url: Base URL of the Gradio app.
        tool_name: MCP tool name, mapped to a Gradio endpoint via ``tool_map``.
        arguments: Tool arguments dict produced by the LLM.

    Returns:
        ``{"text": first_output}`` on success, otherwise ``{"error": ...}``.

    Raises:
        requests.RequestException, ValueError, UnicodeDecodeError: network failures or
        malformed responses; ``call_mcp_tool`` converts these into ``{"error": ...}``.
    """
    tool_map = {"get_quote": "test_quote_tool", "get_market_news": "test_market_news_tool", "get_company_news": "test_company_news_tool"}
    gradio_fn = tool_map.get(tool_name)
    if not gradio_fn:
        return {"error": "No mapping"}

    # Positional inputs, in the order the Gradio endpoint declares them.
    if tool_name == "get_quote":
        params = [arguments.get("symbol", "")]
    elif tool_name == "get_market_news":
        params = [arguments.get("category", "general")]
    elif tool_name == "get_company_news":
        params = [arguments.get("symbol", ""), arguments.get("from_date", ""), arguments.get("to_date", "")]
    else:
        params = []

    # Step 1: submit the call and obtain an event id.
    call_url = f"{service_url}/call/{gradio_fn}"
    resp = requests.post(call_url, json={"data": params}, timeout=10)
    if resp.status_code != 200:
        return {"error": f"HTTP {resp.status_code}"}

    event_id = resp.json().get("event_id")
    if not event_id:
        return {"error": "No event_id"}

    # Step 2: read the SSE result stream. The response is used as a context manager so
    # the streamed connection is released on every exit path, including early returns.
    with requests.get(f"{call_url}/{event_id}", stream=True, timeout=20) as result_resp:
        if result_resp.status_code != 200:
            return {"error": f"HTTP {result_resp.status_code}"}

        event_name = None
        for raw_line in result_resp.iter_lines():
            if not raw_line:
                # A blank line ends an SSE event; the next event starts without a name.
                event_name = None
                continue
            line = raw_line.decode('utf-8')
            if line.startswith("event:"):
                event_name = line[len("event:"):].strip()
                continue
            if not line.startswith('data: '):
                continue

            raw_data = line[len('data: '):]
            try:
                payload = json.loads(raw_data)
            except json.JSONDecodeError:
                if event_name == "error":
                    return {"error": raw_data}
                continue

            if event_name == "error":
                return {"error": str(payload) if payload else "Gradio call failed"}
            if isinstance(payload, list) and payload:
                return {"text": payload[0]}

    return {"error": "No result"}

# ============================================================
# End of MCP service invocation code
# ============================================================

# Model used for every chat-completion request made by chatbot_response.
_LLM_MODEL = "Qwen/Qwen3-32B:groq"


def _history_to_messages(history):
    """Convert recent Gradio chat history into chat-completion messages.

    Accepts both Gradio history formats:
      * "tuples":   ``[[user_msg, assistant_msg], ...]``
      * "messages": ``[{"role": "user" | "assistant", "content": ...}, ...]``

    Only the last ``MAX_HISTORY_TURNS`` turns are replayed. Assistant replies are cut to
    ``MAX_HISTORY_CHARS`` characters to keep the prompt small; user messages are not.
    """
    if not history:
        return []

    converted = []
    if all(isinstance(item, dict) for item in history):
        # "messages" format: one dict per message, so a turn is two entries.
        for item in history[-2 * MAX_HISTORY_TURNS:]:
            content = item.get("content")
            # Non-text content (e.g. uploaded files) cannot be forwarded to the LLM.
            if not isinstance(content, str):
                continue
            role = item.get("role")
            if role == "user":
                converted.append({"role": "user", "content": content})
            elif role == "assistant":
                converted.append({"role": "assistant", "content": truncate_text(content, MAX_HISTORY_CHARS)})
        return converted

    # "tuples" format: one [user, assistant] pair per turn.
    for item in history[-MAX_HISTORY_TURNS:]:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            converted.append({"role": "user", "content": item[0]})
            converted.append({"role": "assistant", "content": truncate_text(item[1], MAX_HISTORY_CHARS)})
    return converted


def _parse_tool_arguments(raw_arguments):
    """Return a tool call's arguments as a dict.

    Providers differ: ``function.arguments`` may be a JSON string or an already
    decoded dict. Anything unparseable, or JSON that is not an object, becomes ``{}``.
    """
    if isinstance(raw_arguments, dict):
        return raw_arguments
    try:
        parsed = json.loads(raw_arguments)
    except (json.JSONDecodeError, TypeError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _tool_result_for_llm(tool_result):
    """Serialize *tool_result* into the JSON string sent back to the LLM.

    Errors are reduced to ``{"error": ...}``. Results whose JSON exceeds
    ``MAX_TOOL_RESULT_CHARS`` are shrunk and marked with ``"_truncated": True``:
    a ``text`` field is cut, other dicts keep at most 8 fields of at most 300 chars
    each, and anything else becomes a truncated ``preview`` string.
    """
    if isinstance(tool_result, dict) and "error" in tool_result:
        return json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)

    result_str = json.dumps(tool_result, ensure_ascii=False)
    if len(result_str) <= MAX_TOOL_RESULT_CHARS:
        return result_str

    if isinstance(tool_result, dict) and "text" in tool_result:
        # Leave ~50 chars of headroom for the JSON wrapper and the marker.
        shrunk = {"text": truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50), "_truncated": True}
    elif isinstance(tool_result, dict):
        shrunk = {}
        char_count = 0
        for key, value in list(tool_result.items())[:8]:
            value_str = str(value)[:300]
            shrunk[key] = value_str
            char_count += len(key) + len(value_str)
            if char_count > MAX_TOOL_RESULT_CHARS:
                break
        shrunk["_truncated"] = True
    else:
        shrunk = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
    return json.dumps(shrunk, ensure_ascii=False)


def _render_tool_calls_html(tool_calls_log):
    """Render executed tool calls as collapsible HTML shown above the answer.

    Each call becomes a native ``<details>`` element (no JavaScript) showing its JSON
    input and up to 1500 characters of pretty-printed JSON output. Names, arguments and
    results come from the LLM or remote services, so they are HTML-escaped first.
    """
    import html

    parts = ["""<div style='margin-bottom: 15px;'>
<div style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333;'>
🛠️ Tools Used ({} calls)
</div>
""".format(len(tool_calls_log))]

    for idx, tool_call in enumerate(tool_calls_log, start=1):
        args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
        result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
        result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
        error_indicator = " ❌ Error" if tool_call.get('error') else ""

        # Escape after truncating so an HTML entity is never cut in half.
        safe_name = html.escape(str(tool_call['name']), quote=False)
        safe_args = html.escape(args_json, quote=False)
        safe_preview = html.escape(result_preview, quote=False)

        parts.append(f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
  <summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
    <div style='display: flex; justify-content: space-between; align-items: center;'>
      <div style='flex: 1;'>
        <strong style='color: #2c5aa0;'>📌 {idx}. {safe_name}{error_indicator}</strong>
        <div style='font-size: 0.85em; color: #666; margin-top: 4px;'>📥 Input: <code style='background: #f5f5f5; padding: 2px 6px; border-radius: 3px;'>{safe_args}</code></div>
      </div>
      <span style='font-size: 1.2em; color: #999; margin-left: 10px;'>▶</span>
    </div>
  </summary>
  <div style='background: #f9f9f9; padding: 12px; border-top: 1px solid #eee;'>
    <div style='font-size: 0.9em; color: #333;'>
      <strong>📤 Output:</strong>
      <pre style='background: #fff; padding: 10px; border-radius: 4px; overflow-x: auto; margin-top: 6px; font-size: 0.85em; border: 1px solid #e0e0e0; max-height: 400px; white-space: pre-wrap;'>{safe_preview}</pre>
    </div>
  </div>
</details>
""")

    parts.append("""</div>

---

""")
    parts.append("\n")
    return "".join(parts)


def _run_tool_loop(messages):
    """Let the LLM call MCP tools for up to ``MAX_TOOL_ITERATIONS`` rounds.

    Appends assistant tool-call messages and tool results to *messages* in place.

    Returns:
        ``(final_content, tool_calls_log)``. ``final_content`` is the model's plain
        answer, or ``None`` when every round still requested tools.
    """
    tool_calls_log = []
    for _ in range(MAX_TOOL_ITERATIONS):
        response = client.chat.completions.create(
            model=_LLM_MODEL,
            messages=messages,
            tools=MCP_TOOLS,
            max_tokens=MAX_OUTPUT_TOKENS,
            temperature=0.5,
            tool_choice="auto",
            stream=False,
        )
        choice = response.choices[0]

        if not choice.message.tool_calls:
            return choice.message.content, tool_calls_log

        messages.append(choice.message)
        for tool_call in choice.message.tool_calls:
            tool_name = tool_call.function.name
            tool_args = _parse_tool_arguments(tool_call.function.arguments)
            tool_result = call_mcp_tool(tool_name, tool_args)

            log_entry = {"name": tool_name, "arguments": tool_args, "result": tool_result}
            if isinstance(tool_result, dict) and "error" in tool_result:
                log_entry["error"] = True
            tool_calls_log.append(log_entry)

            messages.append({
                "role": "tool",
                "name": tool_name,
                "content": _tool_result_for_llm(tool_result),
                "tool_call_id": tool_call.id,
            })
    return None, tool_calls_log


def _stream_final_answer(messages, response_prefix):
    """Ask the LLM (tools disabled) for a final answer and yield it progressively.

    Streams first; if streaming raises, falls back to one non-streaming request.
    Every yielded value is *response_prefix* followed by the answer so far.
    """
    try:
        stream = client.chat.completions.create(
            model=_LLM_MODEL,
            messages=messages,
            tools=None,  # no more tool calls allowed
            max_tokens=MAX_OUTPUT_TOKENS,
            temperature=0.5,
            stream=True,
        )
        accumulated_text = ""
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                accumulated_text += chunk.choices[0].delta.content
                yield response_prefix + accumulated_text
    except Exception:
        final_resp = client.chat.completions.create(
            model=_LLM_MODEL,
            messages=messages,
            tools=None,
            max_tokens=MAX_OUTPUT_TOKENS,
            temperature=0.5,
            stream=False,
        )
        yield response_prefix + (final_resp.choices[0].message.content or "")


def chatbot_response(message, history):
    """Gradio ChatInterface handler: answer *message* using the LLM plus MCP tools.

    Args:
        message: The new user message.
        history: Prior chat history in Gradio "tuples" or "messages" format.

    Yields:
        Progressively longer strings. The first is the HTML tool-call log (empty when
        no tools ran); later ones append the answer. If the tool loop ends without a
        plain answer, one more completion with tools disabled is streamed. Any exception
        becomes a single error string.
    """
    try:
        messages = [{"role": "system", "content": get_system_prompt()}]
        messages.extend(_history_to_messages(history))
        messages.append({"role": "user", "content": message})

        final_response_content, tool_calls_log = _run_tool_loop(messages)

        response_prefix = _render_tool_calls_html(tool_calls_log) if tool_calls_log else ""
        yield response_prefix

        if final_response_content:
            yield response_prefix + final_response_content
        else:
            # Iteration budget exhausted (or empty answer): ask for a summary without tools.
            yield from _stream_final_answer(messages, response_prefix)

    except Exception as e:
        import traceback
        error_detail = str(e)
        if "500" in error_detail:
            yield f"❌ Error: 模型服务器错误。可能是数据太大或请求超时。\n\n详细信息: {error_detail[:200]}"
        else:
            yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"

# ========== Gradio UI (minimal) ==========
with gr.Blocks(title="Financial AI Assistant") as demo:
    gr.Markdown("# 💬 Financial AI Assistant")

    # Clickable starter prompts shown with the chat box.
    _example_prompts = [
        "What's Apple's latest revenue and profit?",
        "Show me NVIDIA's 3-year financial trends",
        "How is Tesla's stock performing today?",
        "Get the latest market news about crypto",
        "Compare Microsoft's latest earnings with its current stock price",
    ]
    _chat_window = gr.Chatbot(height=700)
    _input_box = gr.Textbox(
        lines=4,
        placeholder="Ask me anything about finance, stocks, or company data...",
        show_label=False,
    )

    chat = gr.ChatInterface(
        fn=chatbot_response,
        examples=_example_prompts,
        chatbot=_chat_window,
        textbox=_input_box,
    )

# Launch the app when run as a script.
if __name__ == "__main__":
    import sys

    # On Linux, explicitly install the stock asyncio event-loop policy before launching
    # (the original comment says this works around an asyncio event-loop problem).
    if sys.platform == 'linux':
        try:
            import asyncio
            asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
        except Exception as exc:
            # Best effort only: report and continue. A bare ``except:`` would also have
            # trapped KeyboardInterrupt/SystemExit.
            print(f"⚠️ Could not set asyncio event loop policy: {exc}")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        ssr_mode=False,
        quiet=False,
    )