File size: 10,851 Bytes
c9edecd 74af71d c9edecd 2179d32 c9edecd 2179d32 c9edecd 74af71d c9edecd 2179d32 c9edecd 2179d32 c9edecd 74af71d c9edecd 2179d32 c9edecd 2179d32 c9edecd 74af71d c9edecd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
"""
智能报告缓存管理器
用于管理 Investment Suggestion 和 Analysis Report 的后台生成和缓存
"""
import html
import threading
import time
from datetime import datetime
from typing import Callable, Dict
class ReportTask:
    """A single report-generation job, meant to be executed on a worker thread.

    ``status`` moves through ``"pending"`` (created) -> ``"running"`` (inside
    :meth:`run`) -> ``"completed"`` or ``"error"``.
    """

    def __init__(self, company: str, report_type: str, generator_func: Callable):
        """Create a pending task.

        Args:
            company: Company the report is about.
            report_type: Report kind (e.g. ``"suggestion"``).
            generator_func: Zero-argument callable that produces the report.
        """
        self.company = company
        self.report_type = report_type
        self.generator_func = generator_func
        self.status = "pending"  # pending, running, completed, error
        self.result = None       # return value of generator_func on success
        self.error = None        # str(exception) on failure
        self.error_trace = None  # formatted traceback on failure
        self.start_time = None
        self.end_time = None
        self.thread = None       # worker thread, assigned by whoever starts the task

    def run(self):
        """Execute ``generator_func`` and record its result or failure.

        ``Exception`` subclasses are captured rather than propagated, so the
        worker thread exits cleanly. ``end_time`` (and ``result``/``error``)
        are assigned *before* the terminal status is published, so any thread
        that observes ``"completed"``/``"error"`` also sees consistent data.
        """
        self.status = "running"
        self.start_time = datetime.now()
        try:
            result = self.generator_func()
        except Exception as exc:
            import traceback
            self.error = str(exc)
            self.error_trace = traceback.format_exc()
            self.end_time = datetime.now()
            self.status = "error"
        else:
            self.result = result
            self.end_time = datetime.now()
            self.status = "completed"

    def get_age_seconds(self):
        """Return seconds since the task finished, or since it started if unfinished.

        Returns 0 for a task that has not started yet.
        """
        if self.end_time:
            return (datetime.now() - self.end_time).total_seconds()
        elif self.start_time:
            return (datetime.now() - self.start_time).total_seconds()
        return 0
class ReportCacheManager:
    """Background generation and caching of per-company reports.

    Tasks are keyed by ``(company, report_type)``. Each report is produced by
    a ``generator_func`` running on a daemon thread, and
    :meth:`get_or_create_report` is a generator that streams loading-state
    HTML while the report is built, then the report itself (or an error /
    timeout panel).
    """

    # Statuses of a task whose worker has not finished. "pending" covers the
    # window between Thread.start() and ReportTask.run() switching the status
    # to "running"; treating it as finished would end the stream early.
    _IN_PROGRESS = ("pending", "running")

    def __init__(self, cache_ttl_seconds=3600, max_cache_size=50,
                 max_wait_seconds=90, poll_interval_seconds=1.0):
        """Create an empty cache.

        Args:
            cache_ttl_seconds: How long a completed report is served from the
                cache, measured from the moment it finished.
            max_cache_size: Cap on the number of tracked tasks. Only finished
                tasks are evicted to honour it; in-progress ones are kept.
            max_wait_seconds: How long a single request streams loading
                updates before giving up (the worker keeps running).
            poll_interval_seconds: Delay between successive loading updates.
        """
        self.cache_ttl = cache_ttl_seconds
        self.max_cache_size = max_cache_size
        self.max_wait = max_wait_seconds
        self.poll_interval = poll_interval_seconds
        self.tasks: Dict[tuple, ReportTask] = {}
        # Guards self.tasks and self.stats. Never held across a ``yield``.
        self.lock = threading.Lock()
        self.stats = {
            "cache_hits": 0,
            "cache_misses": 0,
            "background_completions": 0,
            "total_requests": 0,
        }

    def get_or_create_report(self, company: str, report_type: str, generator_func: Callable):
        """Stream the report for ``(company, report_type)``, generating it if needed.

        Behaviour by cache state:

        * completed and younger than ``cache_ttl``: yields the cached result once;
        * a task for the same key is pending/running: yields loading HTML until
          it finishes, then its result or an error panel;
        * otherwise (no entry, expired entry, previous failure): starts
          ``generator_func`` on a daemon thread and streams it the same way.

        If the task is still unfinished after ``max_wait`` seconds, a
        "still generating" panel is yielded. The worker keeps running, so a
        later request can pick up its result.

        Args:
            company: Company the report is about.
            report_type: ``"suggestion"`` for an Investment Suggestion; any
                other value is labelled "Analysis Report".
            generator_func: Zero-argument callable that builds the report. Its
                return value is yielded as-is.

        Yields:
            Loading/error/timeout HTML strings, and the generator_func result.
        """
        key = (company, report_type)
        is_hit = False
        joined_existing = False

        # Make every decision under the lock, but release it before the first
        # yield: the consumer may pause, abandon or close() this generator at
        # any yield point, and a lock held there would stall every request.
        with self.lock:
            self.stats["total_requests"] += 1
            self._cleanup_expired_cache()
            task = self.tasks.get(key)

            if task is not None and task.status == "completed":
                age = task.get_age_seconds()
                if age < self.cache_ttl:
                    self.stats["cache_hits"] += 1
                    print(f"✅ [Cache HIT] {company} - {report_type} (age: {age:.1f}s)")
                    is_hit = True
                else:
                    print(f"⏰ [Cache EXPIRED] {company} - {report_type}")
                    del self.tasks[key]
                    task = None
            elif task is not None and task.status in self._IN_PROGRESS:
                # Same key is already being generated: attach to that task
                # instead of launching a duplicate.
                self.stats["cache_hits"] += 1
                print(f"🔄 [Task RUNNING] {company} - {report_type}, waiting for background task")
                joined_existing = True
            elif task is not None and task.status == "error":
                print(f"🔄 [Retry after ERROR] {company} - {report_type}")
                del self.tasks[key]
                task = None

            if task is None:
                self.stats["cache_misses"] += 1
                print(f"🆕 [Cache MISS] {company} - {report_type} - Starting background generation")
                task = ReportTask(company, report_type, generator_func)
                self.tasks[key] = task
                task.thread = threading.Thread(target=task.run, daemon=True)
                task.thread.start()

        if is_hit:
            # A completed task's result no longer changes, so it is safe to
            # read after releasing the lock.
            yield task.result
            return

        first_elapsed = task.get_age_seconds() if joined_existing else 0
        yield self._get_loading_html(company, report_type, first_elapsed)
        yield from self._stream_loading(task, company, report_type)

        label = "Background" if joined_existing else "NEW"
        status = task.status  # snapshot once so the branches below agree
        if status == "completed":
            if joined_existing:
                with self.lock:
                    self.stats["background_completions"] += 1
            print(f"✅ [{label} COMPLETED] {company} - {report_type}")
            yield task.result
        elif status == "error":
            print(f"❌ [{label} ERROR] {company} - {report_type}: {task.error}")
            yield self._get_error_html(company, report_type, task.error)
        else:
            print(f"⏳ [{label} TIMEOUT] {company} - {report_type} still {status} after {self.max_wait}s")
            yield self._get_timeout_html(company, report_type)

    def _stream_loading(self, task, company: str, report_type: str):
        """Yield loading HTML every poll interval while *task* is unfinished, for at most ``max_wait`` seconds."""
        deadline = time.monotonic() + self.max_wait
        while task.status in self._IN_PROGRESS and time.monotonic() < deadline:
            time.sleep(self.poll_interval)
            yield self._get_loading_html(company, report_type, task.get_age_seconds())

    def _cleanup_expired_cache(self):
        """Drop expired completed reports, then enforce ``max_cache_size``.

        Must be called with ``self.lock`` held. Only finished tasks
        (completed or error) are eviction candidates, oldest ``end_time``
        first; in-progress tasks may have waiters and are never evicted.
        """
        expired = [
            key for key, task in self.tasks.items()
            if task.status == "completed" and task.get_age_seconds() > self.cache_ttl
        ]
        for key in expired:
            company, report_type = key
            print(f"🗑️ [Cache CLEANUP] {company} - {report_type}")
            del self.tasks[key]

        overflow = len(self.tasks) - self.max_cache_size
        if overflow <= 0:
            return
        finished = [(k, t) for k, t in self.tasks.items() if t.status not in self._IN_PROGRESS]
        finished.sort(key=lambda item: item[1].end_time or datetime.min)
        # Slicing (not indexing) copes with fewer finished tasks than overflow.
        for key, _task in finished[:overflow]:
            company, report_type = key
            print(f"🗑️ [Cache SIZE LIMIT] {company} - {report_type}")
            del self.tasks[key]

    @staticmethod
    def _report_name(report_type: str) -> str:
        """Return the human-readable name for a report type."""
        return "Investment Suggestion" if report_type == "suggestion" else "Analysis Report"

    def _get_loading_html(self, company: str, report_type: str, elapsed_seconds: float):
        """Render the spinner shown while a report is being generated.

        ``company`` is HTML-escaped because it is interpolated into markup.
        """
        report_name = self._report_name(report_type)
        elapsed_str = f"{elapsed_seconds:.0f}s" if elapsed_seconds > 0 else "just started"
        return f'''
<div style="display: flex; justify-content: center; align-items: center; height: 200px;">
<div style="text-align: center;">
<div class="loading-spinner" style="width: 40px; height: 40px; border: 4px solid #f3f3f3; border-top: 4px solid #3498db; border-radius: 50%; animation: spin 1s linear infinite; margin: 0 auto;"></div>
<p style="margin-top: 20px; color: #666;">
🤖 Generating {report_name} for <strong>{html.escape(company)}</strong>...<br>
<small>Elapsed: {elapsed_str}</small>
</p>
<style>
@keyframes spin {{
0% {{ transform: rotate(0deg); }}
100% {{ transform: rotate(360deg); }}
}}
</style>
</div>
</div>
'''

    def _get_error_html(self, company: str, report_type: str, error: str):
        """Render the failure panel; company and error text are HTML-escaped."""
        report_name = self._report_name(report_type)
        return f'''
<div style="padding: 20px; background-color: #fff3cd; border-left: 4px solid #ffc107; border-radius: 4px;">
<h4 style="margin-top: 0; color: #856404;">⚠️ Generation Failed</h4>
<p><strong>Report:</strong> {report_name}</p>
<p><strong>Company:</strong> {html.escape(company)}</p>
<p><strong>Error:</strong> {html.escape(str(error))}</p>
</div>
'''

    def _get_timeout_html(self, company: str, report_type: str):
        """Render the panel shown when ``max_wait`` elapses before the task finishes."""
        report_name = self._report_name(report_type)
        return f'''
<div style="padding: 20px; background-color: #e8f4fd; border-left: 4px solid #3498db; border-radius: 4px;">
<h4 style="margin-top: 0; color: #1f618d;">⏳ Still Generating</h4>
<p><strong>Report:</strong> {report_name}</p>
<p><strong>Company:</strong> {html.escape(company)}</p>
<p>No result after {self.max_wait}s. Generation continues in the background; please request the report again shortly.</p>
</div>
'''

    def get_stats(self):
        """Return a snapshot of the counters plus derived metrics.

        ``active_tasks`` counts pending and running tasks; ``cached_reports``
        counts completed ones (expired entries included until the next cleanup).
        """
        with self.lock:
            total = self.stats["total_requests"]
            hit_rate = (self.stats["cache_hits"] / total * 100) if total > 0 else 0
            return {
                **self.stats,
                "hit_rate": f"{hit_rate:.1f}%",
                "active_tasks": sum(1 for t in self.tasks.values() if t.status in self._IN_PROGRESS),
                "cached_reports": sum(1 for t in self.tasks.values() if t.status == "completed"),
            }
|