rkihacker committed on
Commit 0e14740 · verified · 1 Parent(s): 6366f6c

Update main.py

Files changed (1)
main.py +35 -64
main.py CHANGED
@@ -18,13 +18,17 @@ logger = logging.getLogger(__name__)
 load_dotenv()
 LLM_API_KEY = os.getenv("LLM_API_KEY")
 
+# ***** CHANGE 1: Add API Key loading confirmation *****
 if not LLM_API_KEY:
     raise RuntimeError("LLM_API_KEY must be set in a .env file.")
+else:
+    logger.info(f"LLM API Key loaded successfully (starts with: {LLM_API_KEY[:4]}...).")
 
-# API URLs and Models
+# API URLs, Models, and a new constant for context size
 SNAPZION_API_URL = "https://search.snapzion.com/get-snippets"
 LLM_API_URL = "https://api.inference.net/v1/chat/completions"
 LLM_MODEL = "mistralai/mistral-nemo-12b-instruct/fp-8"
+MAX_CONTEXT_CHAR_LENGTH = 120000  # Safeguard: roughly 30k tokens
 
 # Headers for external services
 SNAPZION_HEADERS = { 'Content-Type': 'application/json', 'User-Agent': 'AI-Deep-Research-Agent/1.0' }
@@ -39,10 +43,10 @@ class DeepResearchRequest(BaseModel):
 app = FastAPI(
     title="AI Deep Research API",
     description="Provides single-shot AI search and streaming deep research completions.",
-    version="2.0.0"
+    version="2.1.0"  # Version bump for new robustness feature
 )
 
-# --- Core Service Functions (Reused and New) ---
+# --- Core Service Functions (Unchanged) ---
 
 async def call_snapzion_search(session: aiohttp.ClientSession, query: str) -> list:
     try:
@@ -52,7 +56,7 @@ async def call_snapzion_search(session: aiohttp.ClientSession, query: str) -> li
         return data.get("organic_results", [])
     except Exception as e:
         logger.error(f"Snapzion search failed for query '{query}': {e}")
-        return []  # Return empty list on failure instead of crashing
+        return []
 
 async def scrape_url(session: aiohttp.ClientSession, url: str) -> str:
     if url.lower().endswith('.pdf'): return "Error: PDF content cannot be scraped."
@@ -69,16 +73,15 @@ async def scrape_url(session: aiohttp.ClientSession, url: str) -> str:
         return f"Error: {e}"
 
 async def search_and_scrape(session: aiohttp.ClientSession, query: str) -> tuple[str, list]:
-    """Performs the search and scrape pipeline for a given query."""
     search_results = await call_snapzion_search(session, query)
-    sources = search_results[:4]  # Use top 4 sources per sub-query
+    sources = search_results[:4]
     if not sources: return "", []
 
     scrape_tasks = [scrape_url(session, source["link"]) for source in sources]
     scraped_contents = await asyncio.gather(*scrape_tasks)
 
     context = "\n\n".join(
-        f"Source [{i+1}] (from {sources[i]['link']}):\n{content}"
+        f"Source Details: Title '{sources[i]['title']}', URL '{sources[i]['link']}'\nContent:\n{content}"
         for i, content in enumerate(scraped_contents) if not content.startswith("Error:")
     )
     return context, sources
@@ -86,35 +89,26 @@ async def search_and_scrape(session: aiohttp.ClientSession, query: str) -> tuple
 # --- Streaming Deep Research Logic ---
 
 async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
-    """The main async generator for the deep research process."""
 
     def format_sse(data: dict) -> str:
-        """Formats a dictionary as a Server-Sent Event string."""
         return f"data: {json.dumps(data)}\n\n"
 
     try:
         async with aiohttp.ClientSession() as session:
-            # Step 1: Generate Sub-Questions
+            # Step 1: Generate Sub-Questions (Unchanged)
             yield format_sse({"event": "status", "data": "Generating research plan..."})
             sub_question_prompt = {
                 "model": LLM_MODEL,
-                "messages": [{
-                    "role": "user",
-                    "content": f"You are a research planner. Based on the user's query '{query}', generate a list of 3 to 4 crucial sub-questions that would form the basis of a comprehensive research report. Respond with ONLY a JSON array of strings. Example: [\"Question 1?\", \"Question 2?\"]"
-                }]
+                "messages": [{ "role": "user", "content": f"You are a research planner. For the topic '{query}', create a JSON array of 3-4 key sub-questions for a research report. Example: [\"Question 1?\", \"Question 2?\"]" }]
             }
             async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=sub_question_prompt) as response:
                 response.raise_for_status()
                 result = await response.json()
-                try:
-                    sub_questions = json.loads(result['choices'][0]['message']['content'])
-                except (json.JSONDecodeError, IndexError):
-                    yield format_sse({"event": "error", "data": "Failed to parse sub-questions from LLM."})
-                    return
+                sub_questions = json.loads(result['choices'][0]['message']['content'])
 
             yield format_sse({"event": "plan", "data": sub_questions})
 
-            # Step 2: Concurrently research all sub-questions
+            # Step 2: Concurrently research all sub-questions (Unchanged)
             research_tasks = [search_and_scrape(session, sq) for sq in sub_questions]
             all_research_results = []
 
@@ -127,7 +121,13 @@ async def run_deep_research_stream(query: str) -> AsyncGenerator[str, None]:
             yield format_sse({"event": "status", "data": "Consolidating research..."})
             full_context = "\n\n---\n\n".join(res[0] for res in all_research_results if res[0])
             all_sources = [source for res in all_research_results for source in res[1]]
-            unique_sources = list({s['link']: s for s in all_sources}.values())  # Deduplicate sources
+            unique_sources = list({s['link']: s for s in all_sources}.values())
+
+            # ***** CHANGE 2: Implement the context truncation safeguard *****
+            logger.info(f"Consolidated context size: {len(full_context)} characters.")
+            if len(full_context) > MAX_CONTEXT_CHAR_LENGTH:
+                logger.warning(f"Context is too long. Truncating from {len(full_context)} to {MAX_CONTEXT_CHAR_LENGTH} characters.")
+                full_context = full_context[:MAX_CONTEXT_CHAR_LENGTH]
 
             if not full_context.strip():
                 yield format_sse({"event": "error", "data": "Failed to gather any research context."})
@@ -135,38 +135,28 @@ Use the context below exclusively. Do not use outside knowledge. Structure the r
 
             # Step 4: Generate the final report with streaming
             yield format_sse({"event": "status", "data": "Generating final report..."})
-
-            final_report_prompt = f"""
-            You are a research analyst. Your task is to synthesize the provided context into a comprehensive, well-structured report on the topic: "{query}".
-            Use the context below exclusively. Do not use outside knowledge. Structure the report with markdown headings.
-
-            ## Research Context ##
-            {full_context}
-            """
+            final_report_prompt = f'Synthesize the provided context into a comprehensive report on "{query}". Use the context exclusively. Structure the report with markdown.\n\n## Research Context ##\n{full_context}'
 
-            final_report_payload = {
-                "model": LLM_MODEL,
-                "messages": [{"role": "user", "content": final_report_prompt}],
-                "stream": True  # Enable streaming from the LLM
-            }
+            final_report_payload = {"model": LLM_MODEL, "messages": [{"role": "user", "content": final_report_prompt}], "stream": True}
 
             async with session.post(LLM_API_URL, headers=LLM_HEADERS, json=final_report_payload) as response:
-                response.raise_for_status()
+                # ***** CHANGE 3: More robust error handling for the streaming call *****
+                if response.status != 200:
+                    error_text = await response.text()
+                    logger.error(f"LLM API returned a non-200 status: {response.status} - {error_text}")
+                    raise Exception(f"LLM API Error: {response.status}, {error_text}")
+
                 async for line in response.content:
+                    # (Rest of the streaming logic is the same)
                     if line.strip():
-                        # The inference API might wrap its stream chunks in a 'data: ' prefix
                         line_str = line.decode('utf-8').strip()
-                        if line_str.startswith('data:'):
-                            line_str = line_str[5:].strip()
-                        if line_str == "[DONE]":
-                            break
+                        if line_str.startswith('data:'): line_str = line_str[5:].strip()
+                        if line_str == "[DONE]": break
                         try:
                             chunk = json.loads(line_str)
                             content = chunk.get("choices", [{}])[0].get("delta", {}).get("content")
-                            if content:
-                                yield format_sse({"event": "chunk", "data": content})
-                        except json.JSONDecodeError:
-                            continue  # Ignore empty or malformed lines
+                            if content: yield format_sse({"event": "chunk", "data": content})
+                        except json.JSONDecodeError: continue
 
             yield format_sse({"event": "sources", "data": unique_sources})
 
@@ -178,25 +168,6 @@ Use the context below exclusively. Do not use outside knowledge. Structure the r
 
 
 # --- API Endpoints ---
-
-@app.get("/", include_in_schema=False)
-def root():
-    return {"message": "AI Deep Research API is active. See /docs for details."}
-
 @app.post("/v1/deepresearch/completions")
 async def deep_research_endpoint(request: DeepResearchRequest):
-    """
-    Performs a multi-step, streaming deep research task.
-
-    **Events Streamed:**
-    - `status`: Provides updates on the current stage of the process.
-    - `plan`: The list of sub-questions that will be researched.
-    - `chunk`: A piece of the final generated report.
-    - `sources`: The list of web sources used for the report.
-    - `error`: Indicates a fatal error occurred.
-    - `done`: Signals the end of the stream.
-    """
-    return StreamingResponse(
-        run_deep_research_stream(request.query),
-        media_type="text/event-stream"
-    )
+    return StreamingResponse(run_deep_research_stream(request.query), media_type="text/event-stream")
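For reference, a minimal client sketch for consuming the SSE stream this endpoint produces. It is illustrative only and not part of the commit: the host and port are hypothetical, and it assumes the `data: {json}\n\n` framing emitted by format_sse plus a `query` field on DeepResearchRequest, as the endpoint body suggests.

import json
import requests  # third-party: pip install requests

def consume_deep_research(query: str, base_url: str = "http://localhost:8000") -> None:
    # POST the research query and read the SSE lines as they arrive.
    with requests.post(
        f"{base_url}/v1/deepresearch/completions",
        json={"query": query},
        stream=True,
    ) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # skip the blank separators between SSE events
            line = raw_line.decode("utf-8")
            if not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            if event["event"] == "chunk":
                # Pieces of the final report stream incrementally.
                print(event["data"], end="", flush=True)
            else:
                # status / plan / sources / error events arrive as whole objects.
                print(f"\n[{event['event']}] {event['data']}")

if __name__ == "__main__":
    consume_deep_research("Impact of context window limits on LLM research agents")

Using requests with stream=True keeps memory flat while the report streams; any SSE-capable client that splits on the "data: " prefix would work the same way.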