|
|
import json
import os
import re
import time

from .mcp_client import MCP_Client
|
|
|
|
|
# --- LLM endpoint (DashScope, OpenAI-compatible mode) -------------------------
# Every value can be overridden through an environment variable; the literal
# is only a fallback so existing deployments keep working unchanged.
BASE_URL = os.environ.get(
    "DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"
)
MODEL = os.environ.get("DASHSCOPE_MODEL", "qwen3-max")
# SECURITY: an API key literal is committed to source control here. Provide
# DASHSCOPE_API_KEY via the environment, rotate this key, then drop the fallback.
TOKEN = os.environ.get("DASHSCOPE_API_KEY", "sk-ef26097310ec45c184e8d84b31ea9356")

# --- MCP server configuration -------------------------------------------------
# Passed to MCP_Client(...) by every tool call in this module. Presumably
# MCP_Client starts the server as `npx alpha-ventage-mcp` with the given extra
# environment variables — confirm against mcp_client.
MCP_CONFIG = {
    "mcpServers": {
        "alpha-vantage": {
            "command": "npx",
            # NOTE(review): "ventage" differs from the server key "alpha-vantage";
            # confirm this is the intended npm package name and not a typo.
            "args": ["alpha-ventage-mcp"],
            "env": {
                # SECURITY: committed key literal; prefer ALPHA_VANTAGE_API_KEY
                # from the environment and remove the fallback once rotated.
                "ALPHA_VANTAGE_API_KEY": os.environ.get(
                    "ALPHA_VANTAGE_API_KEY", "97Q9TT7I6J9ZOLDS"
                ),
            },
        }
    }
}
|
|
|
|
|
|
|
|
# Chinese company names -> ticker symbols. Scanned in insertion order; the
# first name that occurs anywhere in the input wins.
_STOCK_NAME_TO_SYMBOL = {
    "阿里巴巴": "BABA",
    "阿里": "BABA",
    "腾讯": "TCEHY",
    "微软": "MSFT",
    "苹果": "AAPL",
    "谷歌": "GOOGL",
    "亚马逊": "AMZN",
    "特斯拉": "TSLA",
    "英伟达": "NVDA",
    "百度": "BIDU",
    "京东": "JD",
    "拼多多": "PDD",
}

# Ticker patterns, most specific first, so an explicit "股票代码..." or
# "XXXX股票" phrase beats an unrelated uppercase word elsewhere in the text.
# re.ASCII makes \b treat only [A-Za-z0-9_] as word characters; without it,
# CJK characters count as word characters and a ticker written directly next
# to Chinese text (e.g. "了解MSFT的") would never match the generic pattern.
_SYMBOL_PATTERNS = tuple(
    re.compile(pattern, re.ASCII)
    for pattern in (
        r"股票代码[是为]?([A-Z]{1,5})",
        r"([A-Z]{1,5})股票",
        r"\b([A-Z]{1,5})\b",
    )
)


def extract_stock_symbol(user_input: str, default: str = "BABA") -> str:
    """Extract a stock ticker symbol from free-form (mainly Chinese) user input.

    Lookup order:
      1. A known Chinese company name from ``_STOCK_NAME_TO_SYMBOL``.
      2. An explicit "股票代码[是/为]XXXX" phrase.
      3. An uppercase ticker written directly before "股票".
      4. Any standalone run of 1-5 uppercase ASCII letters.

    Examples:
      - "查询阿里巴巴股票" -> "BABA"
      - "查询AAPL股票" -> "AAPL"
      - "我想了解MSFT的股价" -> "MSFT"
      - "BABA股价如何" -> "BABA"

    Args:
        user_input: Raw text typed by the user.
        default: Symbol returned when nothing is recognised. Defaults to
            "BABA", the previous hard-coded fallback.

    Returns:
        The matched ticker symbol, or ``default``.
    """
    for name, symbol in _STOCK_NAME_TO_SYMBOL.items():
        if name in user_input:
            return symbol

    for pattern in _SYMBOL_PATTERNS:
        match = pattern.search(user_input)
        if match:
            # The capture group is [A-Z]{1,5}, so it is already a valid
            # 1-5 letter uppercase symbol.
            return match.group(1)

    return default
|
|
|
|
|
|
|
|
def _call_mcp_tool(tool_name: str, arguments: dict):
    """Run one MCP tool call on a fresh client and return its raw result.

    A new ``MCP_Client`` is created from ``MCP_CONFIG`` for every call. After a
    successful ``initialize()`` the client is always closed, whether the tool
    call returns or raises.

    Raises:
        RuntimeError: if ``client.initialize()`` returns a falsy value.
        Exception: anything raised by ``client.call_tool`` propagates unchanged.
    """
    client = MCP_Client(MCP_CONFIG)
    if not client.initialize():
        # NOTE(review): the client is not closed on this path (same as before);
        # confirm whether MCP_Client.close() is safe after a failed initialize().
        raise RuntimeError("MCP初始化失败")
    try:
        return client.call_tool(tool_name, arguments)
    finally:
        client.close()


def get_stock_price(symbol: str):
    """Return the raw result of the MCP ``get_stock_price`` tool for ``symbol``."""
    return _call_mcp_tool("get_stock_price", {"symbol": symbol})


def get_company_overview(symbol: str):
    """Return the raw result of the MCP ``get_company_overview`` tool for ``symbol``."""
    return _call_mcp_tool("get_company_overview", {"symbol": symbol})


def get_daily_time_series(symbol: str):
    """Return the raw result of the MCP ``get_daily_time_series`` tool for ``symbol``."""
    return _call_mcp_tool("get_daily_time_series", {"symbol": symbol})


def get_forex_rate(from_currency: str = "USD", to_currency: str = "CNY"):
    """Return the raw result of the MCP ``get_forex_rate`` tool for a currency pair."""
    return _call_mcp_tool(
        "get_forex_rate",
        {"from_currency": from_currency, "to_currency": to_currency},
    )
|
|
|
|
|
|
|
|
def get_tools_for_stock(symbol: str):
    """Build the ordered list of MCP tool invocations to run for one stock.

    Every entry is a dict with three keys:
      - ``name``: the MCP tool name to call,
      - ``description``: a human-readable label (shown in the tool panel),
      - ``params``: the arguments passed to the tool.

    The three per-stock tools each get their own ``{"symbol": symbol}`` dict.
    The forex tool always asks for USD -> CNY.
    """
    per_symbol_tools = (
        ("get_stock_price", "获取股票价格"),
        ("get_company_overview", "获取公司概况"),
        ("get_daily_time_series", "获取每日时间序列数据"),
    )
    tool_list = [
        {"name": tool_name, "description": label, "params": {"symbol": symbol}}
        for tool_name, label in per_symbol_tools
    ]
    tool_list.append(
        {
            "name": "get_forex_rate",
            "description": "获取外汇汇率",
            "params": {"from_currency": "USD", "to_currency": "CNY"},
        }
    )
    return tool_list
|
|
|
|
|
|
|
|
def call_llm(prompt: str, timeout=(10, 120)):
    """Stream a chat completion for ``prompt`` from the configured LLM endpoint.

    This makes a real HTTP request; it is not a mock. It POSTs a single user
    message to ``{BASE_URL}/chat/completions`` with ``MODEL`` and
    ``"stream": True``, authenticated with ``TOKEN``. It then yields each
    non-empty line of the response body, decoded as UTF-8. Callers in this
    module strip an optional ``data: `` prefix and parse the rest as JSON.

    The HTTP response is closed when the generator finishes or is closed early.

    Args:
        prompt: User message content.
        timeout: ``requests`` timeout, either seconds or a ``(connect, read)``
            tuple. The read timeout bounds the wait between received bytes,
            not the total length of the stream.

    Raises:
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
            Previously such errors were silently turned into empty output.
        requests.RequestException: on connection failures or timeouts.
    """
    # Imported lazily, as before, so the module can be imported without requests.
    import requests

    headers = {
        "Authorization": f"Bearer {TOKEN}",
        "Content-Type": "application/json",
    }
    data = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    with requests.post(
        f"{BASE_URL}/chat/completions",
        headers=headers,
        json=data,
        stream=True,
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                yield line.decode("utf-8")
|
|
|
|
|
|
|
|
def get_stock_data_example(symbol: str):
    """Example: fetch and print all demo data sets for one stock symbol.

    It calls, in order, ``get_stock_price``, ``get_company_overview``,
    ``get_daily_time_series`` (all with ``symbol``) and ``get_forex_rate()``
    (USD -> CNY), and prints each raw result. Each helper call opens its own
    MCP client session.

    Returns:
        dict with keys ``"price"``, ``"overview"``, ``"time_series"`` and
        ``"forex"`` holding the raw tool results.

    Raises:
        Exception: any error from the helpers is printed, then re-raised with
            its original traceback.
    """
    try:
        price_data = get_stock_price(symbol)
        print(f"股票价格数据: {price_data}")

        overview_data = get_company_overview(symbol)
        print(f"公司概况数据: {overview_data}")

        time_series_data = get_daily_time_series(symbol)
        print(f"时间序列数据: {time_series_data}")

        forex_data = get_forex_rate()
        print(f"外汇汇率数据: {forex_data}")

        return {
            "price": price_data,
            "overview": overview_data,
            "time_series": time_series_data,
            "forex": forex_data,
        }
    except Exception as e:
        print(f"获取股票数据时出错: {str(e)}")
        # Bare raise re-raises the active exception without adding this frame.
        raise
|
|
|
|
|
|
|
|
# Tool-panel status: CSS class used on the header -> label shown to the user.
_TOOL_STATUS_LABELS = {
    "querying": "查询中",
    "analyzing": "分析中",
    "completed": "已完成",
    "failed": "失败",
}

# Opening tag of the container that receives the streamed per-tool summary.
_ANALYSIS_OPEN = "<div id='analysis-result'>"

# Opening markup of the boxed final-summary message (closed with "</div>").
_FINAL_SUMMARY_HEADER = '''<div style="border: 2px solid #4CAF50; border-radius: 8px; padding: 20px; background-color: #f8fff8; margin: 15px 0;">
<h3 style="color: #2E7D32; text-align: center; margin-top: 0;">📈 最终总结</h3>'''


def _render_tool_panel(index, total, tool, status, body):
    """Return the complete, well-formed HTML of one collapsible tool panel.

    ``status`` is a key of ``_TOOL_STATUS_LABELS``. It is used as the CSS class
    of the header and also selects the visible status label. ``body`` is HTML
    placed inside the panel's content <div>. Rebuilding the whole panel on
    every update keeps the <details>/<div> tags balanced at all times.
    """
    label = _TOOL_STATUS_LABELS[status]
    name = tool["name"]
    description = tool["description"]
    return f'''<details>
<summary class="tool-header {status}" style="background-color: #f0f8ff; border: 1px solid #d0e0f0; border-radius: 5px; padding: 10px; margin: 10px 0; color: #1e40af; font-weight: bold; cursor: pointer;">
[MCP] <span class="status-tag {status}">{label}</span> 🔧 工具 {index}/{total}: {name} ({description})
</summary>
<div style="border: 1px solid #d0e0f0; border-top: none; border-radius: 0 0 5px 5px; padding: 15px; background-color: #f8fafc; margin-bottom: 10px;">{body}</div>
</details>'''


def _iter_llm_deltas(prompt):
    """Yield each non-empty text piece streamed back by ``call_llm(prompt)``.

    Each line is expected to be an SSE event: an optional ``data: `` prefix and
    then a JSON object whose ``choices[0].delta.content`` holds the next piece
    of text. The following are skipped instead of raising: the ``[DONE]``
    sentinel, lines that are not JSON objects, and chunks with an empty or
    missing ``choices`` list.
    """
    for line in call_llm(prompt):
        payload = line[len("data: "):] if line.startswith("data: ") else line
        if not payload or payload == "[DONE]":
            continue
        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            continue
        if not isinstance(event, dict):
            continue
        choices = event.get("choices")
        # Indexing [0] unguarded raised IndexError on chunks with "choices": [].
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            continue
        delta = choices[0].get("delta")
        content = delta.get("content") if isinstance(delta, dict) else None
        if content:
            yield content


def _extract_tool_text(result):
    """Return the text of the first content item of a tool result, or None.

    Accepts a dict whose ``content`` is a non-empty list. A dict first item
    yields its ``text`` (default ""); any other first item is stringified.
    Every other shape returns None.
    """
    if not isinstance(result, dict):
        return None
    content = result.get("content")
    if not isinstance(content, list) or not content:
        return None
    first = content[0]
    return first.get("text", "") if isinstance(first, dict) else str(first)


def _run_tools(client, tools, history):
    """Call every tool and stream an LLM summary of each result into its panel.

    Appends one assistant "panel" message per tool to ``history``. A failing
    tool also gets a separate error message. ``("", history)`` is yielded
    after each visible change.

    Returns (via ``yield from``):
        list[str]: one line per tool, used to build the final summary prompt.
    """
    all_results = []
    total = len(tools)
    for index, tool in enumerate(tools, 1):
        panel = {
            "role": "assistant",
            "content": _render_tool_panel(index, total, tool, "querying", ""),
        }
        history.append(panel)
        yield "", history

        body = ""
        try:
            result = client.call_tool(tool["name"], tool["params"])
            text_content = _extract_tool_text(result)

            if text_content is None:
                # Nothing to summarise: finish the panel instead of leaving it
                # stuck (and unclosed) in the "querying" state.
                body = "<p>⚠️ 未返回可分析的数据</p>"
                panel["content"] = _render_tool_panel(index, total, tool, "completed", body)
                all_results.append(f"{tool['description']}: 未返回数据")
            else:
                summary_prompt = f"给我总结这个信息内容:{text_content}"
                model_summary = ""
                body = f"{_ANALYSIS_OPEN}</div>"
                panel["content"] = _render_tool_panel(index, total, tool, "analyzing", body)
                yield "", history

                for piece in _iter_llm_deltas(summary_prompt):
                    model_summary += piece
                    body = f"{_ANALYSIS_OPEN}{model_summary}</div>"
                    panel["content"] = _render_tool_panel(index, total, tool, "analyzing", body)
                    yield "", history

                result_msg = f"✅ {tool['name']}: {model_summary}"
                body += f"<p><strong>分析结果:</strong></p><p>{result_msg}</p>"
                panel["content"] = _render_tool_panel(index, total, tool, "completed", body)
                all_results.append(f"{tool['description']}: {model_summary}")

            yield "", history
        except Exception as e:
            # Close *this tool's* panel (previously the closing tags were
            # appended to the error message instead), keeping any partial output.
            panel["content"] = _render_tool_panel(index, total, tool, "failed", body)
            history.append({"role": "assistant", "content": f"❌ {tool['name']} 查询失败: {str(e)}"})
            yield "", history
            all_results.append(f"{tool['name']}: 查询失败 - {str(e)}")

        # Brief pause between consecutive tool calls.
        time.sleep(0.5)

    return all_results


def _stream_final_summary(user_input, all_results, history):
    """Stream the LLM's overall answer into a single boxed assistant message.

    Previously the header box, the streamed answer and a second copy of the
    answer were three separate messages. Now the answer appears exactly once,
    inside the box.
    """
    all_results_text = "\n".join(all_results)
    summary_prompt = f"""用户问题: {user_input}

收集到的金融数据:
{all_results_text}

请根据以上数据回答用户问题,提供简洁明了的总结。"""

    progress = "📊 正在分析所有数据并生成总结..."
    message = {
        "role": "assistant",
        "content": f"{_FINAL_SUMMARY_HEADER}\n{progress}\n</div>",
    }
    history.append(message)
    yield "", history

    bot_response = ""
    for piece in _iter_llm_deltas(summary_prompt):
        bot_response += piece
        message["content"] = f"{_FINAL_SUMMARY_HEADER}\n{progress}\n{bot_response}</div>"
        yield "", history

    # Drop the progress line once the answer is complete.
    message["content"] = f"{_FINAL_SUMMARY_HEADER}\n{bot_response}</div>"
    yield "", history


def process_tool_analysis(user_input: str, history: list):
    """Answer a stock question by running MCP tools and streaming LLM summaries.

    Steps:
      1. Resolve a ticker from ``user_input`` and build the tool list.
      2. Run each tool on one shared MCP client, streaming a per-tool summary.
      3. Stream a final LLM summary over all collected tool results.

    Args:
        user_input: The user's question. Blank input yields once and stops.
        history: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts. It is mutated in place: the user turn and all assistant
            messages are appended.

    Yields:
        ``("", history)`` after every visible change. The empty string
        presumably clears the UI input box — confirm with the caller.
        Unexpected errors are reported as a final "❌ 错误: ..." message
        rather than raised.
    """
    if not (user_input or "").strip():
        yield "", history
        return

    history.append({"role": "user", "content": user_input})

    try:
        stock_symbol = extract_stock_symbol(user_input)
        tools_to_test = get_tools_for_stock(stock_symbol)

        client = MCP_Client(MCP_CONFIG)
        if not client.initialize():
            history.append({"role": "assistant", "content": "❌ MCP初始化失败"})
            yield "", history
            return

        # try/finally also releases the client if the consumer stops iterating
        # this generator (GeneratorExit) or a non-tool step raises.
        try:
            all_results = yield from _run_tools(client, tools_to_test, history)
        finally:
            client.close()

        yield from _stream_final_summary(user_input, all_results, history)
    except Exception as e:
        history.append({"role": "assistant", "content": f"❌ 错误: {str(e)}"})
        yield "", history
|
|
|