JC321 commited on
Commit
b94319a
·
verified ·
1 Parent(s): f04e9b1

Delete EasyFinancialAgent/chat_direct copy 2.py

Browse files
EasyFinancialAgent/chat_direct copy 2.py DELETED
@@ -1,1112 +0,0 @@
1
- """
2
- Financial AI Assistant - Direct Method Library (不依赖 HTTP)
3
- 直接导入并调用 easy_financial_mcp.py 中的函数
4
- 支持本地和 HF Space 部署
5
- """
6
-
7
- import sys
8
- from pathlib import Path
9
- import os
10
- import json
11
- from dotenv import load_dotenv
12
- from huggingface_hub import InferenceClient
13
- import requests
14
- import warnings
15
-
16
- # 抑削 asyncio 警告
17
- warnings.filterwarnings('ignore', category=DeprecationWarning)
18
- os.environ['PYTHONWARNINGS'] = 'ignore'
19
-
20
- # 先加载 .env 文件
21
- load_dotenv()
22
-
23
- # 添加服务模块路径
24
- PROJECT_ROOT = Path(__file__).parent.parent.absolute()
25
- sys.path.insert(0, str(PROJECT_ROOT))
26
-
27
- # 直接导入 MCP 中定义的函数
28
- try:
29
- from EasyFinancialAgent.easy_financial_mcp import (
30
- search_company as _search_company,
31
- get_company_info as _get_company_info,
32
- get_company_filings as _get_company_filings,
33
- get_financial_data as _get_financial_data,
34
- extract_financial_metrics as _extract_financial_metrics,
35
- get_latest_financial_data as _get_latest_financial_data,
36
- advanced_search_company as _advanced_search_company
37
- )
38
- MCP_DIRECT_AVAILABLE = True
39
- print("[FinancialAI] ✓ Direct MCP functions imported successfully")
40
- except ImportError as e:
41
- MCP_DIRECT_AVAILABLE = False
42
- print(f"[FinancialAI] ✗ Failed to import MCP functions: {e}")
43
- # 定义占位符函数
44
- def _advanced_search_company(x):
45
- return {"error": "MCP not available"}
46
- def _get_company_info(x):
47
- return {"error": "MCP not available"}
48
- def _get_company_filings(x, y=None):
49
- return {"error": "MCP not available"}
50
- def _get_financial_data(x, y):
51
- return {"error": "MCP not available"}
52
- def _get_latest_financial_data(x):
53
- return {"error": "MCP not available"}
54
- def _extract_financial_metrics(x, y=3):
55
- return {"error": "MCP not available"}
56
-
57
-
58
- # ============================================================
59
- # 便捷方法 - 公司搜索相关
60
- # ============================================================
61
-
62
- def search_company_direct(company_input):
63
- """
64
- 批量搜索公司信息(直接调用)
65
-
66
- 使用 advanced_search_company 工具,支持公司名称、Ticker 或 CIK 代码
67
-
68
- Args:
69
- company_input: 公司名称、Ticker 代码或 CIK 代码
70
-
71
- Returns:
72
- 批量搜索结果
73
-
74
- Example:
75
- result = search_company_direct("Apple")
76
- result = search_company_direct("AAPL")
77
- result = search_company_direct("0000320193")
78
- """
79
- if not MCP_DIRECT_AVAILABLE:
80
- return {"error": "MCP functions not available"}
81
-
82
- try:
83
- result = _advanced_search_company(company_input)
84
- return [result]
85
- except Exception as e:
86
- return {"error": str(e)}
87
-
88
-
89
- def get_company_info_direct(cik):
90
- """
91
- 获取公司详细信息(直接调用)
92
-
93
- Args:
94
- cik: 公司 CIK 代码
95
-
96
- Returns:
97
- 公司信息
98
-
99
- Example:
100
- result = get_company_info_direct("0000320193")
101
- """
102
- if not MCP_DIRECT_AVAILABLE:
103
- return {"error": "MCP functions not available"}
104
-
105
- try:
106
- return _get_company_info(cik)
107
- except Exception as e:
108
- return {"error": str(e)}
109
-
110
-
111
- def get_company_filings_direct(cik):
112
- """
113
- 获取公司 SEC 文件列表(直接调用)
114
-
115
- Args:
116
- cik: 公司 CIK 代码
117
-
118
- Returns:
119
- 文件列表
120
-
121
- Example:
122
- result = get_company_filings_direct("0000320193")
123
- """
124
- if not MCP_DIRECT_AVAILABLE:
125
- return {"error": "MCP functions not available"}
126
-
127
- try:
128
- return _get_company_filings(cik)
129
- except Exception as e:
130
- return {"error": str(e)}
131
-
132
-
133
- def advanced_search_company_detailed(company_input):
134
- """
135
- 高级公司搜索 - 支持公司名称、Ticker 或 CIK 的强大搜索方法
136
-
137
- 不同于 search_company_direct,该方法来自 EasyReportDataMCP 中的 mcp_server_fastmcp
138
- 更具有灵活性,可以自动检测输入的类型
139
-
140
- Args:
141
- company_input: 公司名称 ("Tesla", "Apple Inc")
142
- Ticker 代码 ("TSLA", "AAPL", "MSFT")
143
- CIK 代码 ("0001318605", "0000320193")
144
-
145
- Returns:
146
- dict: 包含以下信息:
147
- - cik: 公司的 Central Index Key
148
- - name: 办公室注册名称
149
- - tickers: 股票代码
150
- - sic: Standard Industrial Classification 代码
151
- - sic_description: 行业/行业描述
152
-
153
- Example:
154
- # 按公司名称搜索
155
- result = advanced_search_company_detailed("Tesla")
156
- # 按 Ticker 搜索
157
- result = advanced_search_company_detailed("TSLA")
158
- # 按 CIK 搜索
159
- result = advanced_search_company_detailed("0001318605")
160
- """
161
- if not MCP_DIRECT_AVAILABLE:
162
- return {"error": "MCP functions not available"}
163
-
164
- try:
165
- # 直接调用 advanced_search_company 工具
166
- result = _advanced_search_company(company_input)
167
- return result
168
- except Exception as e:
169
- import traceback
170
- return {
171
- "error": str(e),
172
- "traceback": traceback.format_exc()
173
- }
174
-
175
-
176
- def format_search_result(search_result):
177
- """
178
- 提取并格式化搜索结果
179
-
180
- 将 advanced_search_company 的结果转换为标准格式:
181
- [{company_name: str, cik: str, ticker: str}]
182
-
183
- Args:
184
- search_result: advanced_search_company 的返回结果
185
- 格式: {'cik': '...', 'name': '...', 'tickers': [...], ...}
186
-
187
- Returns:
188
- list[dict]: 格式化的结果
189
- [
190
- {
191
- 'company_name': str, # 公司名称
192
- 'cik': str, # CIK 代码
193
- 'ticker': str # 第一个股票代码
194
- }
195
- ]
196
-
197
- Example:
198
- search_result = {'cik': '0001577552', 'name': 'Alibaba Group Holding Ltd', 'tickers': ['BABA'], '_source': 'company_tickers_cache'}
199
- formatted = format_search_result(search_result)
200
- # 输出: [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
201
- """
202
- # 处理错误情况
203
- if isinstance(search_result, dict) and 'error' in search_result:
204
- return []
205
-
206
- # 处理列表情况
207
- if isinstance(search_result, list):
208
- formatted_list = []
209
- for item in search_result:
210
- formatted_item = format_search_result(item)
211
- formatted_list.extend(formatted_item)
212
- return formatted_list
213
-
214
- # 处理单个字典
215
- if not isinstance(search_result, dict):
216
- return []
217
-
218
- try:
219
- company_name = search_result.get('name', '')
220
- cik = search_result.get('cik', '')
221
- tickers = search_result.get('tickers', [])
222
-
223
- # 取数组的第一个元素,或使用空字符串
224
- ticker = tickers[0] if isinstance(tickers, list) and len(tickers) > 0 else ''
225
-
226
- return [{
227
- 'company_name': company_name,
228
- 'cik': cik,
229
- 'ticker': ticker
230
- }]
231
- except Exception as e:
232
- return []
233
-
234
-
235
- def search_and_format(company_input):
236
- """
237
- 搎合搜索并立即格式化结果
238
-
239
- 一个一步到位的便法方法,执行搜索并格式化结果
240
-
241
- Args:
242
- company_input: 公司名称、Ticker 或 CIK
243
-
244
- Returns:
245
- list[dict]: 格式化的结果
246
-
247
- Example:
248
- result = search_and_format('BABA')
249
- # 输出: [{'company_name': 'Alibaba Group Holding Ltd', 'cik': '0001577552', 'ticker': 'BABA'}]
250
- """
251
- # 执行搜索
252
- search_result = advanced_search_company_detailed(company_input)
253
-
254
- # 检查是否有错误
255
- if isinstance(search_result, dict) and 'error' in search_result:
256
- return []
257
-
258
- # 格式化结果
259
- return format_search_result(search_result)
260
-
261
-
262
- # ============================================================
263
- # 便捷方法 - 财务数据相关
264
- # ============================================================
265
-
266
- def get_latest_financial_data_direct(cik):
267
- """
268
- 获取公司最新财务数据(直接调用)
269
-
270
- Args:
271
- cik: 公司 CIK 代码
272
-
273
- Returns:
274
- 最新财务数据
275
-
276
- Example:
277
- result = get_latest_financial_data_direct("0000320193")
278
- """
279
- if not MCP_DIRECT_AVAILABLE:
280
- return {"error": "MCP functions not available"}
281
-
282
- try:
283
- return _get_latest_financial_data(cik)
284
- except Exception as e:
285
- return {"error": str(e)}
286
-
287
-
288
- def extract_financial_metrics_direct(cik, years=5):
289
- """
290
- 提取多年财务指标趋势(直接调用)
291
-
292
- Args:
293
- cik: 公司 CIK 代码
294
- years: 年数(默认 3 年)
295
-
296
- Returns:
297
- 财务指标数据
298
-
299
- Example:
300
- result = extract_financial_metrics_direct("0000320193", years=5)
301
- """
302
- if not MCP_DIRECT_AVAILABLE:
303
- return {"error": "MCP functions not available"}
304
-
305
- try:
306
- return _extract_financial_metrics(cik, years)
307
- except Exception as e:
308
- return {"error": str(e)}
309
-
310
-
311
- # ============================================================
312
- # 高级方法 - 综合查询
313
- # ============================================================
314
-
315
- def query_company_direct(company_input, get_filings=True, get_metrics=True):
316
- """
317
- 综合查询公司信息(直接调用)
318
- 包括搜索、基本信息、文件列表和财务指标
319
-
320
- Args:
321
- company_input: 公司名称或代码
322
- get_filings: 是否获取文件列表
323
- get_metrics: 是否获取财务指标
324
-
325
- Returns:
326
- 综合结果字典,包含 search, company_info, filings, metrics
327
-
328
- Example:
329
- result = query_company_direct("Apple", get_filings=True, get_metrics=True)
330
- """
331
- from datetime import datetime
332
-
333
- result = {
334
- "timestamp": datetime.now().isoformat(),
335
- "query_input": company_input,
336
- "status": "success",
337
- "data": {
338
- "company_search": None,
339
- "company_info": None,
340
- "filings": None,
341
- "metrics": None
342
- },
343
- "errors": []
344
- }
345
-
346
- if not MCP_DIRECT_AVAILABLE:
347
- result["status"] = "error"
348
- result["errors"].append("MCP functions not available")
349
- return result
350
-
351
- try:
352
- # 1. 搜索公司
353
- search_result = search_company_direct(company_input)
354
- if "error" in search_result:
355
- result["errors"].append(f"Search error: {search_result['error']}")
356
- result["status"] = "error"
357
- return result
358
-
359
- result["data"]["company_search"] = search_result
360
-
361
- # 从搜索结果提取 CIK
362
- cik = None
363
- if isinstance(search_result, dict):
364
- cik = search_result.get("cik")
365
- elif isinstance(search_result, (list, tuple)) and len(search_result) > 0:
366
- # 从列表中获取第一个元素
367
- try:
368
- first_item = search_result[0] if isinstance(search_result, (list, tuple)) else None
369
- if isinstance(first_item, dict):
370
- cik = first_item.get("cik")
371
- except (IndexError, TypeError):
372
- pass
373
-
374
- if not cik:
375
- result["errors"].append("Could not extract CIK from search result")
376
- result["status"] = "error"
377
- return result
378
-
379
- # 2. 获取公司信息
380
- company_info = get_company_info_direct(cik)
381
- if "error" not in company_info:
382
- result["data"]["company_info"] = company_info
383
- else:
384
- result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
385
-
386
- # 3. 获取文件列表
387
- if get_filings:
388
- filings = get_company_filings_direct(cik)
389
- if "error" not in filings:
390
- result["data"]["filings"] = filings
391
- else:
392
- result["errors"].append(f"Failed to get filings: {filings.get('error')}")
393
-
394
- # 4. 获取财务指标
395
- if get_metrics:
396
- metrics = extract_financial_metrics_direct(cik, years=3)
397
- if "error" not in metrics:
398
- result["data"]["metrics"] = metrics
399
- else:
400
- result["errors"].append(f"Failed to get metrics: {metrics.get('error')}")
401
-
402
- except Exception as e:
403
- result["status"] = "error"
404
- result["errors"].append(f"Exception: {str(e)}")
405
- import traceback
406
- result["errors"].append(traceback.format_exc())
407
-
408
- return result
409
-
410
-
411
- # ============================================================
412
- # LLM 模型配置与初始化
413
- # ============================================================
414
-
415
- # 初始化 LLM 客户端
416
- def _init_llm_client():
417
- """初始化 LLM 客户端"""
418
- global llm_client
419
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
420
- llm_client = None
421
- try:
422
- if hf_token:
423
- llm_client = InferenceClient(api_key=hf_token)
424
- print("[FinancialAI] ✓ LLM client initialized with HF_TOKEN")
425
- return True
426
- else:
427
- print("[FinancialAI] ⚠ Warning: HF_TOKEN not found, LLM features disabled")
428
- return False
429
- except Exception as e:
430
- print(f"[FinancialAI] ✗ Failed to initialize LLM client: {e}")
431
- return False
432
-
433
- # 全局 llm_client 变量
434
- llm_client = None
435
- _init_llm_client()
436
-
437
-
438
- def get_system_prompt():
439
- """生成系统提示词"""
440
- from datetime import datetime
441
- current_date = datetime.now().strftime("%Y-%m-%d")
442
- return f"""You are a financial analysis expert. Today is {current_date}.
443
- Your role:
444
- - Analyze company financial data, reports, and market news
445
- - Provide investment insights based on factual data
446
- - Be concise, objective, and data-driven
447
- - Always include disclaimers about market risks
448
-
449
- ⚠️ IMPORTANT: You have a maximum of 5 tool calls. Choose the MOST RELEVANT tools carefully:
450
- - Use 'advanced_search_company' ONLY if you need to find a company's CIK
451
- - Use 'extract_financial_metrics' for comprehensive multi-year financial analysis (RECOMMENDED for most queries)
452
- - Use 'get_latest_financial_data' for quick recent snapshot
453
- - Use 'get_quote' for real-time stock price
454
- - Use 'get_company_news' for company-specific news
455
- - Use 'get_market_news' for general market trends
456
-
457
- Prioritize the most important tools for the user's question. Avoid redundant calls.
458
- Output should be in English."""
459
-
460
-
461
- def analyze_company_with_llm(company_input, analysis_type="summary"):
462
- """
463
- 使用 LLM 分析公司信息
464
-
465
- Args:
466
- company_input: 公司名称或代码
467
- analysis_type: 分析类型 ("summary", "investment", "risks")
468
-
469
- Returns:
470
- LLM 分析结果
471
-
472
- Example:
473
- result = analyze_company_with_llm("Apple", "investment")
474
- """
475
- if not llm_client:
476
- return {"error": "LLM client not available"}
477
-
478
- if not MCP_DIRECT_AVAILABLE:
479
- return {"error": "MCP functions not available"}
480
-
481
- try:
482
- # 先获取公司财务数据
483
- company_data = get_company_summary_direct(company_input)
484
- if company_data["status"] == "error":
485
- return {"error": f"Failed to fetch company data: {company_data['errors']}"}
486
-
487
- # 构建提示
488
- data_str = json.dumps(company_data["data"], ensure_ascii=False, indent=2)
489
-
490
- if analysis_type == "investment":
491
- prompt = f"""
492
- Based on the following company financial data, provide an investment recommendation:
493
-
494
- {data_str}
495
-
496
- Provide:
497
- 1. Investment Recommendation (Buy/Hold/Sell)
498
- 2. Key Strengths and Weaknesses
499
- 3. Price Target Range
500
- 4. Risk Assessment
501
- 5. Risk Disclaimer
502
- """
503
- elif analysis_type == "risks":
504
- prompt = f"""
505
- Based on the following company data, analyze the key risks:
506
-
507
- {data_str}
508
-
509
- Identify:
510
- 1. Financial Risks
511
- 2. Market Risks
512
- 3. Operational Risks
513
- 4. Mitigation Strategies
514
- 5. Risk Disclaimer
515
- """
516
- else: # summary
517
- prompt = f"""
518
- Provide a financial summary of the following company:
519
-
520
- {data_str}
521
-
522
- Include:
523
- 1. Company Overview
524
- 2. Financial Health
525
- 3. Recent Performance
526
- 4. Investment Outlook
527
- """
528
-
529
- # 调用 LLM
530
- response = llm_client.chat.completions.create(
531
- model="Qwen/Qwen2.5-72B-Instruct",
532
- messages=[
533
- {"role": "system", "content": get_system_prompt()},
534
- {"role": "user", "content": prompt}
535
- ],
536
- max_tokens=1500,
537
- temperature=0.7,
538
- top_p=0.95,
539
- stream=False
540
- )
541
-
542
- return {
543
- "company": company_input,
544
- "analysis_type": analysis_type,
545
- "analysis": response.choices[0].message.content,
546
- "data_used": company_data["data"]
547
- }
548
-
549
- except Exception as e:
550
- return {"error": f"LLM analysis failed: {str(e)}"}
551
-
552
-
553
- # ============================================================
554
- # 便捷方法 - 获取单一时期财务数据
555
- # ============================================================
556
-
557
- def get_financial_data_direct(cik, period):
558
- """
559
- 获取指定时期的财务数据(直接调用)
560
-
561
- Args:
562
- cik: 公司 CIK 代码
563
- period: 时期 (e.g., "2024", "2024Q3")
564
-
565
- Returns:
566
- 财务数据
567
-
568
- Example:
569
- result = get_financial_data_direct("0000320193", "2024")
570
- """
571
- if not MCP_DIRECT_AVAILABLE:
572
- return {"error": "MCP functions not available"}
573
-
574
- try:
575
- return _get_financial_data(cik, period)
576
- except Exception as e:
577
- return {"error": str(e)}
578
-
579
-
580
- # ============================================================
581
- # 便捷方法 - 获取文件列表
582
- # ============================================================
583
-
584
- def get_company_filings_with_form_direct(cik, form_types=None):
585
- """
586
- 获取指定类型的公司 SEC 文件列表(直接调用)
587
-
588
- Args:
589
- cik: 公司 CIK 代码
590
- form_types: 表单类型列表 (e.g., ["10-K", "10-Q"])
591
-
592
- Returns:
593
- 文件列表
594
-
595
- Example:
596
- result = get_company_filings_with_form_direct("0000320193", ["10-K"])
597
- """
598
- if not MCP_DIRECT_AVAILABLE:
599
- return {"error": "MCP functions not available"}
600
-
601
- try:
602
- return _get_company_filings(cik, form_types)
603
- except Exception as e:
604
- return {"error": str(e)}
605
-
606
-
607
- # ============================================================
608
- # 便捷方法 - 轻量级查询
609
- # ============================================================
610
-
611
- def get_company_summary_direct(company_input):
612
- """
613
- 获取公司简要摘要信息(轻量级查询,仅搜索和基本信息)
614
-
615
- Args:
616
- company_input: 公司名称或代码
617
-
618
- Returns:
619
- 公司摘要数据
620
-
621
- Example:
622
- result = get_company_summary_direct("Apple")
623
- """
624
- from datetime import datetime
625
-
626
- result = {
627
- "timestamp": datetime.now().isoformat(),
628
- "query_input": company_input,
629
- "status": "success",
630
- "data": {
631
- "company_search": None,
632
- "company_info": None
633
- },
634
- "errors": []
635
- }
636
-
637
- if not MCP_DIRECT_AVAILABLE:
638
- result["status"] = "error"
639
- result["errors"].append("MCP functions not available")
640
- return result
641
-
642
- try:
643
- # 1. 搜索公司
644
- search_result = search_company_direct(company_input)
645
- if "error" in search_result:
646
- result["errors"].append(f"Search error: {search_result['error']}")
647
- result["status"] = "error"
648
- return result
649
-
650
- result["data"]["company_search"] = search_result
651
-
652
- # 从搜索结果提取 CIK
653
- cik = None
654
- if isinstance(search_result, dict):
655
- cik = search_result.get("cik")
656
- elif isinstance(search_result, (list, tuple)) and len(search_result) > 0:
657
- try:
658
- first_item = search_result[0]
659
- if isinstance(first_item, dict):
660
- cik = first_item.get("cik")
661
- except (IndexError, TypeError):
662
- pass
663
-
664
- if not cik:
665
- result["errors"].append("Could not extract CIK from search result")
666
- result["status"] = "error"
667
- return result
668
-
669
- # 2. 获取公司信息
670
- company_info = get_company_info_direct(cik)
671
- if "error" not in company_info:
672
- result["data"]["company_info"] = company_info
673
- else:
674
- result["errors"].append(f"Failed to get company info: {company_info.get('error')}")
675
-
676
- except Exception as e:
677
- result["status"] = "error"
678
- result["errors"].append(f"Exception: {str(e)}")
679
- import traceback
680
- result["errors"].append(traceback.format_exc())
681
-
682
- return result
683
-
684
-
685
- def get_financial_metrics_only_direct(company_input, years=5):
686
- """
687
- 获取公司财务指标趋势(仅财务指标,不获取文件列表)
688
-
689
- Args:
690
- company_input: 公司名称或代码
691
- years: 年数(默认 5 年)
692
-
693
- Returns:
694
- 财务指标数据
695
-
696
- Example:
697
- result = get_financial_metrics_only_direct("Apple", years=5)
698
- """
699
- from datetime import datetime
700
-
701
- result = {
702
- "timestamp": datetime.now().isoformat(),
703
- "query_input": company_input,
704
- "years": years,
705
- "status": "success",
706
- "data": None,
707
- "errors": []
708
- }
709
-
710
- if not MCP_DIRECT_AVAILABLE:
711
- result["status"] = "error"
712
- result["errors"].append("MCP functions not available")
713
- return result
714
-
715
- try:
716
- # 1. 搜索公司
717
- search_result = search_company_direct(company_input)
718
- if "error" in search_result:
719
- result["errors"].append(f"Search error: {search_result['error']}")
720
- result["status"] = "error"
721
- return result
722
-
723
- # 从搜索结果提取 CIK
724
- cik = None
725
- if isinstance(search_result, dict):
726
- cik = search_result.get("cik")
727
- elif isinstance(search_result, (list, tuple)) and len(search_result) > 0:
728
- try:
729
- first_item = search_result[0]
730
- if isinstance(first_item, dict):
731
- cik = first_item.get("cik")
732
- except (IndexError, TypeError):
733
- pass
734
-
735
- if not cik:
736
- result["errors"].append("Could not extract CIK from search result")
737
- result["status"] = "error"
738
- return result
739
-
740
- # 2. 获取财务指标
741
- metrics = extract_financial_metrics_direct(cik, years=years)
742
- if "error" in metrics:
743
- result["errors"].append(f"Failed to get metrics: {metrics['error']}")
744
- result["status"] = "error"
745
- else:
746
- result["data"] = metrics
747
-
748
- except Exception as e:
749
- result["status"] = "error"
750
- result["errors"].append(f"Exception: {str(e)}")
751
- import traceback
752
- result["errors"].append(traceback.format_exc())
753
-
754
- return result
755
-
756
-
757
- # ============================================================
758
- # 测试函数
759
- # ============================================================
760
-
761
- if __name__ == "__main__":
762
- print("\n" + "="*60)
763
- print("Financial AI Assistant - Direct Method Test")
764
- print("="*60)
765
-
766
- # 测试 1: 公司搜索
767
- print("\n1. 搜索公司 (Apple)...")
768
- result = search_company_direct("Apple")
769
- print(f" 结果: {result}")
770
-
771
- # 测试 2: 公司摘要
772
- print("\n2. 获取公司摘要信息 (Tesla)...")
773
- summary = get_company_summary_direct("Tesla")
774
- print(f" 状态: {summary['status']}")
775
- print(f" 数据: {summary['data']}")
776
- print(f" 错误: {summary['errors']}")
777
-
778
- # 测试 3: 财务指标
779
- print("\n3. 获取财务指标 (Microsoft)...")
780
- metrics = get_financial_metrics_only_direct("Microsoft", years=3)
781
- print(f" 状态: {metrics['status']}")
782
- if metrics['status'] == 'success':
783
- print(f" 指标数据: {metrics['data']}")
784
- else:
785
- print(f" 错误: {metrics['errors']}")
786
-
787
- # 测试 4: 完整查询
788
- print("\n4. 获取 Amazon 完整信息...")
789
- full_query = query_company_direct("Amazon", get_filings=True, get_metrics=True)
790
- print(f" 状态: {full_query['status']}")
791
- print(f" 错误: {full_query['errors']}")
792
-
793
- # 测试 5: LLM 分析 - 摘要
794
- print("\n5. LLM 分析 - 公司摘要(Google)...")
795
- if llm_client:
796
- llm_result = analyze_company_with_llm("Google", "summary")
797
- if "error" in llm_result:
798
- print(f" 错误: {llm_result['error']}")
799
- else:
800
- print(f" 分析结果: {llm_result['analysis'][:200]}...")
801
- else:
802
- print(" LLM 客户端不可用")
803
-
804
- # 测试 6: LLM 分析 - 投资建议
805
- print("\n6. LLM 分析 - 投资建议(NVIDIA)...")
806
- if llm_client:
807
- llm_result = analyze_company_with_llm("NVIDIA", "investment")
808
- if "error" in llm_result:
809
- print(f" 错误: {llm_result['error']}")
810
- else:
811
- print(f" 分析结果: {llm_result['analysis'][:200]}...")
812
- else:
813
- print(" LLM 客户端不可用")
814
-
815
- # 测试 7: LLM 分析 - 风险评估
816
- print("\n7. LLM 分析 - 风险评估(Meta)...")
817
- if llm_client:
818
- llm_result = analyze_company_with_llm("Meta", "risks")
819
- if "error" in llm_result:
820
- print(f" 错误: {llm_result['error']}")
821
- else:
822
- print(f" 分析结果: {llm_result['analysis'][:200]}...")
823
- else:
824
- print(" LLM 客户端不可用")
825
-
826
- print("\n" + "="*60)
827
-
828
-
829
- # ============================================================
830
- # 完整对话引擎 - chatbot_response
831
- # ============================================================
832
-
833
- # Token 限制配置
834
- MAX_TOTAL_TOKENS = 6000
835
- MAX_TOOL_RESULT_CHARS = 1500
836
- MAX_HISTORY_CHARS = 500
837
- MAX_HISTORY_TURNS = 2
838
- MAX_TOOL_ITERATIONS = 5 # ✅ 限制最多调用5个工具,确保选择最合适的工具
839
- MAX_OUTPUT_TOKENS = 2000
840
-
841
- # MCP 工具配置 - 包含财务数据和市场新闻工具
842
- MCP_TOOLS = [
843
- # 财务数据工具 (EasyReportDataMCP)
844
- {"type": "function", "function": {"name": "advanced_search_company", "description": "Search US companies by name, ticker, or CIK. Returns company information including CIK, name, tickers, and industry classification.", "parameters": {"type": "object", "properties": {"company_input": {"type": "string", "description": "Company name (e.g., 'Tesla'), ticker symbol (e.g., 'TSLA'), or CIK code (e.g., '0001318605')"}}, "required": ["company_input"]}}},
845
- {"type": "function", "function": {"name": "get_latest_financial_data", "description": "Get the most recent financial data for a company including revenue, net income, EPS, operating expenses, and cash flow.", "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format, e.g., '0001318605')"}}, "required": ["cik"]}}},
846
- {"type": "function", "function": {"name": "extract_financial_metrics", "description": "Extract multi-year financial metrics trends showing historical performance over specified years.", "parameters": {"type": "object", "properties": {"cik": {"type": "string", "description": "Company CIK code (10-digit format)"}, "years": {"type": "integer", "description": "Number of years of data to retrieve (e.g., 3 or 5)", "default": 3}}, "required": ["cik", "years"]}}},
847
-
848
- # 市场和新闻工具 (MarketandStockMCP)
849
- {"type": "function", "function": {"name": "get_quote", "description": "Get real-time stock quote data including current price, daily change, high/low, and previous close. Use when users ask about current stock prices or market performance.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Stock ticker symbol (e.g., 'AAPL', 'TSLA', 'MSFT')"}}, "required": ["symbol"]}}},
850
- {"type": "function", "function": {"name": "get_market_news", "description": "Get latest market news by category. Use when users ask about general market trends, forex, crypto, or M&A news.", "parameters": {"type": "object", "properties": {"category": {"type": "string", "enum": ["general", "forex", "crypto", "merger"], "description": "News category: general (stocks/economy), forex (currency), crypto (cryptocurrency), merger (M&A)", "default": "general"}, "min_id": {"type": "integer", "description": "Minimum news ID for pagination (default: 0)", "default": 0}}, "required": ["category"]}}},
851
- {"type": "function", "function": {"name": "get_company_news", "description": "Get company-specific news and announcements. Only available for North American companies. Use when users ask about specific company news.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "Company stock ticker symbol (e.g., 'AAPL', 'TSLA')"}, "from_date": {"type": "string", "description": "Start date in YYYY-MM-DD format (optional, defaults to 7 days ago)"}, "to_date": {"type": "string", "description": "End date in YYYY-MM-DD format (optional, defaults to today)"}}, "required": ["symbol"]}}}
852
- ]
853
-
854
-
855
- def truncate_text(text, max_chars, suffix="...[truncated]"):
856
- """截断文本到指定长度"""
857
- text = str(text)
858
- if len(text) <= max_chars:
859
- return text
860
- return text[:max_chars] + suffix
861
-
862
-
863
- def call_mcp_tool(tool_name, arguments):
864
- """直接调用 MCP 工具函数(不通过HTTP)"""
865
- try:
866
- # ✅ 财务数据工具 - 直接调用 Python 函数
867
- if tool_name == "advanced_search_company":
868
- company_input = arguments.get("company_input", "")
869
- return _advanced_search_company(company_input)
870
-
871
- elif tool_name == "get_latest_financial_data":
872
- cik = arguments.get("cik", "")
873
- return _get_latest_financial_data(cik)
874
-
875
- elif tool_name == "extract_financial_metrics":
876
- cik = arguments.get("cik", "")
877
- years = arguments.get("years", 3)
878
- return _extract_financial_metrics(cik, years)
879
-
880
- # ✅ 市场和新闻工具 - 直接调用 Python 函数
881
- elif tool_name == "get_quote":
882
- from MarketandStockMCP.news_quote_mcp import get_quote
883
- symbol = arguments.get("symbol", "")
884
- return get_quote(symbol)
885
-
886
- elif tool_name == "get_market_news":
887
- from MarketandStockMCP.news_quote_mcp import get_market_news
888
- category = arguments.get("category", "general")
889
- min_id = arguments.get("min_id", 0)
890
- return get_market_news(category, min_id)
891
-
892
- elif tool_name == "get_company_news":
893
- from MarketandStockMCP.news_quote_mcp import get_company_news
894
- symbol = arguments.get("symbol", "")
895
- from_date = arguments.get("from_date")
896
- to_date = arguments.get("to_date")
897
- return get_company_news(symbol, from_date, to_date)
898
-
899
- else:
900
- return {"error": f"Unknown tool: {tool_name}"}
901
-
902
- except Exception as e:
903
- import traceback
904
- return {
905
- "error": f"{str(e)}",
906
- "traceback": traceback.format_exc()[:500]
907
- }
908
-
909
-
910
- def chatbot_response(message, history=None):
911
- """
912
- AI 助手主函数(完整对话引擎)
913
- 支持多轮对话、动态工具调用、流式输出
914
-
915
- Args:
916
- message: 用户消息
917
- history: 对话历史,格式: [(user_msg, assistant_msg), ...]
918
-
919
- Returns:
920
- 生成器,不断 yield 响应文本
921
-
922
- Example:
923
- for response in chatbot_response("What's Apple's revenue?", []):
924
- print(response)
925
- """
926
- if not llm_client:
927
- yield "❌ Error: LLM client not available"
928
- return
929
-
930
- if not MCP_DIRECT_AVAILABLE:
931
- yield "❌ Error: MCP functions not available"
932
- return
933
-
934
- try:
935
- messages = [{"role": "system", "content": get_system_prompt()}]
936
-
937
- # 添加历史(最近2轮) - 严格限制上下文长度
938
- if history:
939
- for item in history[-MAX_HISTORY_TURNS:]:
940
- if isinstance(item, (list, tuple)) and len(item) == 2:
941
- messages.append({"role": "user", "content": item[0]})
942
- assistant_msg = str(item[1])
943
- if len(assistant_msg) > MAX_HISTORY_CHARS:
944
- assistant_msg = truncate_text(assistant_msg, MAX_HISTORY_CHARS)
945
- messages.append({"role": "assistant", "content": assistant_msg})
946
-
947
- messages.append({"role": "user", "content": message})
948
-
949
- tool_calls_log = []
950
- final_response_content = None
951
-
952
- # LLM 调用循环(支持多轮工具调用)
953
- for iteration in range(MAX_TOOL_ITERATIONS):
954
- response = llm_client.chat.completions.create(
955
- model="Qwen/Qwen2.5-72B-Instruct",
956
- messages=messages,
957
- tools=MCP_TOOLS, # type: ignore
958
- max_tokens=MAX_OUTPUT_TOKENS,
959
- temperature=0.7,
960
- tool_choice="auto",
961
- stream=False
962
- )
963
-
964
- choice = response.choices[0]
965
-
966
- if choice.message.tool_calls:
967
- messages.append(choice.message)
968
-
969
- for tool_call in choice.message.tool_calls:
970
- tool_name = tool_call.function.name
971
- try:
972
- tool_args = json.loads(tool_call.function.arguments)
973
- except json.JSONDecodeError:
974
- tool_args = {}
975
-
976
- tool_result = call_mcp_tool(tool_name, tool_args)
977
-
978
- if isinstance(tool_result, dict) and "error" in tool_result:
979
- tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result, "error": True})
980
- result_for_llm = json.dumps({"error": tool_result.get("error", "Unknown error")}, ensure_ascii=False)
981
- else:
982
- result_str = json.dumps(tool_result, ensure_ascii=False)
983
-
984
- if len(result_str) > MAX_TOOL_RESULT_CHARS:
985
- if isinstance(tool_result, dict) and "text" in tool_result:
986
- truncated_text = truncate_text(tool_result["text"], MAX_TOOL_RESULT_CHARS - 50)
987
- tool_result_truncated = {"text": truncated_text, "_truncated": True}
988
- elif isinstance(tool_result, dict):
989
- truncated = {}
990
- char_count = 0
991
- for k, v in list(tool_result.items())[:8]:
992
- v_str = str(v)[:300]
993
- truncated[k] = v_str
994
- char_count += len(k) + len(v_str)
995
- if char_count > MAX_TOOL_RESULT_CHARS:
996
- break
997
- tool_result_truncated = {**truncated, "_truncated": True}
998
- else:
999
- tool_result_truncated = {"preview": truncate_text(result_str, MAX_TOOL_RESULT_CHARS), "_truncated": True}
1000
- result_for_llm = json.dumps(tool_result_truncated, ensure_ascii=False)
1001
- else:
1002
- result_for_llm = result_str
1003
-
1004
- tool_calls_log.append({"name": tool_name, "arguments": tool_args, "result": tool_result})
1005
-
1006
- messages.append({
1007
- "role": "tool",
1008
- "name": tool_name,
1009
- "content": result_for_llm,
1010
- "tool_call_id": tool_call.id
1011
- })
1012
-
1013
- continue
1014
- else:
1015
- final_response_content = choice.message.content
1016
- break
1017
-
1018
- response_prefix = ""
1019
-
1020
- if tool_calls_log:
1021
- # ✅ 可折叠的工具调用显示,点击三角形展开/收起
1022
- tool_count = len(tool_calls_log)
1023
-
1024
- # 添加CSS样式,实现三角形旋转动画
1025
- response_prefix += """<style>
1026
- details.tools-container > summary::before {
1027
- content: '▶';
1028
- display: inline-block;
1029
- margin-right: 8px;
1030
- transition: transform 0.2s;
1031
- }
1032
- details.tools-container[open] > summary::before {
1033
- transform: rotate(90deg);
1034
- }
1035
- details.tools-container > summary {
1036
- list-style: none;
1037
- }
1038
- details.tools-container > summary::-webkit-details-marker {
1039
- display: none;
1040
- }
1041
- </style>
1042
- """
1043
-
1044
- response_prefix += f"""<div style='margin-bottom: 15px;'>
1045
- <details class='tools-container' open>
1046
- <summary style='background: #f0f0f0; padding: 8px 12px; border-radius: 6px; font-weight: 600; color: #333; cursor: pointer; user-select: none;'>
1047
- <span>🛠️ Tools Used ({tool_count}/{MAX_TOOL_ITERATIONS} calls)</span>
1048
- </summary>
1049
- <div style='margin-top: 8px;'>
1050
- """
1051
-
1052
- for idx, tool_call in enumerate(tool_calls_log):
1053
- args_json = json.dumps(tool_call['arguments'], ensure_ascii=False)
1054
- result_json = json.dumps(tool_call.get('result', {}), ensure_ascii=False, indent=2)
1055
- result_preview = result_json[:1500] + ('...' if len(result_json) > 1500 else '')
1056
- error_indicator = " ❌ Error" if tool_call.get('error') else ""
1057
-
1058
- response_prefix += f"""<details style='margin: 8px 0; border: 1px solid #ddd; border-radius: 6px; overflow: hidden;'>
1059
- <summary style='background: #fff; padding: 10px; cursor: pointer; user-select: none; list-style: none;'>
1060
- <strong style='color: #2c5aa0;'>📋 {idx+1}. {tool_call['name']}{error_indicator}</strong>
1061
- </summary>
1062
- <div style='background: #f9f9f9; padding: 12px;'>
1063
- <pre style='background: #fff; padding: 10px; overflow-x: auto; font-size: 0.85em;'>{result_preview}</pre>
1064
- </div>
1065
- </details>
1066
- """
1067
-
1068
- # ✅ 关闭外层details和div标签
1069
- response_prefix += """ </div>
1070
- </details>
1071
- </div>
1072
- ---
1073
- """
1074
-
1075
- yield response_prefix
1076
-
1077
- if final_response_content:
1078
- yield response_prefix + final_response_content
1079
- else:
1080
- try:
1081
- stream = llm_client.chat.completions.create(
1082
- model="Qwen/Qwen2.5-72B-Instruct",
1083
- messages=messages,
1084
- tools=None,
1085
- max_tokens=MAX_OUTPUT_TOKENS,
1086
- temperature=0.7,
1087
- stream=True
1088
- )
1089
-
1090
- accumulated_text = ""
1091
- for chunk in stream:
1092
- if chunk.choices and len(chunk.choices) > 0 and chunk.choices[0].delta.content:
1093
- accumulated_text += chunk.choices[0].delta.content
1094
- yield response_prefix + accumulated_text
1095
- except Exception:
1096
- final_resp = llm_client.chat.completions.create(
1097
- model="Qwen/Qwen2.5-72B-Instruct",
1098
- messages=messages,
1099
- tools=None,
1100
- max_tokens=MAX_OUTPUT_TOKENS,
1101
- temperature=0.7,
1102
- stream=False
1103
- )
1104
- yield response_prefix + (final_resp.choices[0].message.content or "")
1105
-
1106
- except Exception as e:
1107
- import traceback
1108
- error_detail = str(e)
1109
- if "500" in error_detail:
1110
- yield f"❌ Error: 模型服务器错误\n\n{error_detail[:200]}"
1111
- else:
1112
- yield f"❌ Error: {error_detail}\n\n{traceback.format_exc()[:500]}"