import os
import sys
import threading
import json
import logging
from typing import Optional, Dict, Any
from fastapi import FastAPI, HTTPException, Header, BackgroundTasks, Depends
from pydantic import BaseModel
from datetime import datetime

# Add src to path
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from src.db.local_database import LocalDatabase, DatabaseEntry, DataType
from run_saturday_analysis import run_saturday_analysis

# Configure logging
# logging.basicConfig(level=logging.INFO)  # Replaced by the handler setup below
logger = logging.getLogger(__name__)

# --- Logging Capture Setup ---
import collections

# Thread-safe ring buffer holding the most recent log lines
log_buffer = collections.deque(maxlen=2000)
class LogCaptureHandler(logging.Handler):
    """Logging handler that appends formatted records to the in-memory buffer."""

    def emit(self, record):
        try:
            # The formatter below already prepends a timestamp via %(asctime)s
            log_entry = self.format(record)
            log_buffer.append(log_entry)
        except Exception:
            self.handleError(record)
# Configure the root logger so ALL logs (including the Coordinator's) are captured
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Avoid adding duplicate handlers if the module is reloaded
if not any(isinstance(h, LogCaptureHandler) for h in root_logger.handlers):
    capture_handler = LogCaptureHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    capture_handler.setFormatter(formatter)
    root_logger.addHandler(capture_handler)

    # Also mirror logs to stdout
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)
app = FastAPI(title="Stock Alchemist Signal Generator")
# --- Models ---
class SignalRequest(BaseModel):
    ticker: Optional[str] = None
    prompt_override: Optional[str] = None

class SignalResponse(BaseModel):
    status: str
    message: str
    signal_id: Optional[str] = None
| def verify_api_secret(x_api_secret: str = Header(...)): | |
| """Verify the API secret header""" | |
| expected_secret = os.getenv("API_SECRET") | |
| if not expected_secret: | |
| # If no secret is set in env, we might want to fail safe or allow default for dev | |
| # For production security, better to fail if not configured. | |
| logger.warning("API_SECRET environment variable not set! Security disabled.") | |
| return # Allow if env var missing (or raise error based on preference) | |
| # raise HTTPException(status_code=500, detail="Server misconfiguration: API_SECRET not set") | |
| if x_api_secret != expected_secret: | |
| raise HTTPException(status_code=403, detail="Invalid API Secret") | |
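
# Usage sketch (illustrative only; the route path and handler below are
# hypothetical, not part of this app). Attaching the dependency per-endpoint
# makes FastAPI return 403 before the handler runs when the X-API-Secret
# header does not match:
#
#     @app.post("/api/admin/reindex", dependencies=[Depends(verify_api_secret)])
#     async def reindex():
#         ...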
# --- Services ---
def generate_signal_logic(ticker: str, prompt_override: Optional[str] = None):
    """
    Core logic to generate a signal using Ollama and save it to the database.
    """
    import requests

    logger.info(f"Generating signal for {ticker}...")

    # 1. Construct the prompt.
    # For now the LLM is asked to rely on its internal knowledge of the ticker;
    # in a real scenario we would fetch news/price data here and feed it in.
    db = LocalDatabase()
    prompt = prompt_override or (
        f"Analyze the stock {ticker} and provide a trading signal (BUY/SELL/HOLD) "
        "with confidence score and reasoning. Format response as JSON."
    )
    try:
        # 2. Call the local Ollama instance
        ollama_url = "http://localhost:11434/api/generate"
        payload = {
            "model": "llama3.1",
            "prompt": prompt,
            "stream": False,
            "format": "json"  # Request JSON-formatted output from Ollama
        }
        response = requests.post(ollama_url, json=payload, timeout=120)
        response.raise_for_status()
        result = response.json()
        llm_output = result.get('response', '')
        logger.info(f"Ollama response for {ticker}: {llm_output[:100]}...")

        # 3. Parse the LLM output and save it as a signal entry.
        # If the output is not a JSON object, wrap the raw text instead.
        try:
            signal_data = json.loads(llm_output)
        except json.JSONDecodeError:
            signal_data = {"raw_output": llm_output}
        if not isinstance(signal_data, dict):
            signal_data = {"raw_output": llm_output}

        # Extract the signal position if present, defaulting to HOLD
        position = str(signal_data.get('signal', signal_data.get('recommendation', 'HOLD'))).upper()
        if position not in ['BUY', 'SELL', 'HOLD']:
            position = 'HOLD'

        # Save via LocalDatabase.save_signal(ticker, calendar_event_keys, news_keys,
        # fundamental_key, signal_position, sentiment). The event/news key lists are
        # left empty because this signal is not linked to specific events.
        is_saved = db.save_signal(
            ticker=ticker,
            calendar_event_keys=[],
            news_keys=[],
            fundamental_key="generated_by_ollama",
            signal_position=position,
            sentiment=signal_data
        )
        if is_saved:
            logger.info(f"Signal saved for {ticker}")
        else:
            logger.error(f"Failed to save signal for {ticker}")
    except Exception as e:
        logger.error(f"Error generating signal for {ticker}: {e}")
# --- Endpoints ---
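# The original route decorators are not included in this snippet; the paths and
# auth dependencies below are assumptions, restored so the endpoints register.
@app.post("/api/signals/generate", response_model=SignalResponse, dependencies=[Depends(verify_api_secret)])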
async def generate_signal(request: SignalRequest, background_tasks: BackgroundTasks):
    """
    Trigger signal generation.
    If a ticker is provided, a signal is generated for it; otherwise the first
    available ticker in the database is used.
    """
    target_ticker = request.ticker
    if not target_ticker:
        try:
            db = LocalDatabase()
            tickers = db.get_all_available_tickers()
            if tickers:
                target_ticker = tickers[0]  # First available ticker for the daily run
            else:
                raise HTTPException(status_code=404, detail="No tickers available in database")
        except HTTPException:
            raise  # Don't re-wrap the 404 above as a 500
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Database error: {e}")

    # Run in the background to avoid request timeouts
    background_tasks.add_task(generate_signal_logic, target_ticker, request.prompt_override)
    return SignalResponse(
        status="accepted",
        message=f"Signal generation started for {target_ticker}"
    )
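
@app.post("/api/saturday-analysis", dependencies=[Depends(verify_api_secret)])  # path assumed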
async def trigger_saturday_analysis(background_tasks: BackgroundTasks):
    """Trigger the Saturday analysis script in the background."""
    background_tasks.add_task(run_saturday_analysis)
    return {"status": "accepted", "message": "Saturday analysis started"}
async def health_check():
    """
    Simple health check.
    Also logs vitals as requested.
    """
    # Verify DB connection
    db_status = "unknown"
    try:
        db = LocalDatabase()
        if db._create_connection():
            db_status = "connected"
        else:
            db_status = "disconnected"
    except Exception as e:
        db_status = f"error: {e}"

    # Check Ollama
    ollama_status = "unknown"
    try:
        import requests
        resp = requests.get("http://localhost:11434/api/tags", timeout=5)
        if resp.status_code == 200:
            ollama_status = "running"
        else:
            ollama_status = f"error: {resp.status_code}"
    except Exception:
        ollama_status = "down"

    vitals = {
        "status": "ok",
        "time": datetime.now().isoformat(),
        "database": db_status,
        "ollama": ollama_status
    }
    logger.info(f"Health Check: {vitals}")
    return vitals
# --- HTML & Public API ---
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi import Request

# Setup templates
templates = Jinja2Templates(directory="src/templates")
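
# StaticFiles is imported above but never mounted in this snippet. A minimal
# mount sketch, assuming assets live in src/static (directory name is an assumption):
if os.path.isdir("src/static"):
    app.mount("/static", StaticFiles(directory="src/static"), name="static")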
# --- Coordinator Startup ---
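@app.on_event("startup")  # decorator assumed; the function below is clearly a startup hook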
async def startup_event():
    """Start the Coordinator (news scraper + scheduler) in the background."""
    # src was added to sys.path at the top of the file, so this import works here
    try:
        from src.orchestrator.coordinator import Coordinator
        logger.info("Initializing Coordinator...")
        coordinator = Coordinator()
        # Run the Coordinator in a daemon thread because start() blocks
        coord_thread = threading.Thread(target=coordinator.start, daemon=True)
        coord_thread.start()
        logger.info("✅ Coordinator started in background thread")
    except Exception as e:
        logger.error(f"❌ Failed to start Coordinator: {e}")
# ... existing code ...
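
@app.get("/api/signals")  # path assumed; public read-only per the docstring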
async def get_signals():
    """Get recent signals (public, read-only)."""
    try:
        db = LocalDatabase()
        signals = db.get_recent_signals(limit=50)
        return signals
    except Exception as e:
        logger.error(f"Error fetching signals: {e}")
        raise HTTPException(status_code=500, detail="Database error")
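
@app.get("/api/test/db", dependencies=[Depends(verify_api_secret)])  # path and auth assumed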
async def test_db_connection():
    """
    Detailed database connection test.
    """
    results = {
        "status": "failed",
        "details": [],
        "config": {}
    }
    try:
        # Report connection config from env vars (password omitted)
        results["config"] = {
            "host": os.getenv('DB_HOST'),
            "port": os.getenv('DB_PORT'),
            "user": os.getenv('DB_USERNAME'),
            "database": os.getenv('DB_DATABASE'),
            "ssl_ca_set": bool(os.getenv('DB_SSL_CA'))
        }
        db = LocalDatabase()
        conn = db._create_connection()
        if conn and conn.is_connected():
            results["status"] = "success"
            results["details"].append("Connection successful")
            results["details"].append(f"Server Info: {conn.get_server_info()}")
            cursor = conn.cursor()
            cursor.execute("SELECT VERSION()")
            version = cursor.fetchone()
            results["details"].append(f"DB Version: {version[0]}")
            conn.close()
        else:
            results["details"].append("Connection object created but is_connected() returned False")
    except Exception as e:
        results["details"].append(f"Exception: {str(e)}")
        import traceback
        results["traceback"] = traceback.format_exc()
    return results
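
@app.get("/api/test/ollama", dependencies=[Depends(verify_api_secret)])  # path and auth assumed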
async def test_ollama_connection():
    """
    Test Ollama connectivity and model status.
    """
    import requests

    results = {
        "status": "failed",
        "details": [],
        "model_found": False
    }
    try:
        # 1. Check whether the service is up
        base_url = "http://localhost:11434"
        try:
            resp = requests.get(f"{base_url}/api/tags", timeout=5)
            if resp.status_code == 200:
                results["details"].append("Ollama Service is UP")
                models = resp.json().get('models', [])
                model_names = [m.get('name') for m in models]
                results["details"].append(f"Available Models: {', '.join(model_names)}")
                if any("llama3.1" in m for m in model_names):
                    results["model_found"] = True
                else:
                    results["details"].append("WARNING: llama3.1 model not found in list!")
            else:
                results["details"].append(f"Service returned status {resp.status_code}")
                return results
        except Exception as e:
            results["details"].append(f"Failed to connect to Ollama Service: {e}")
            return results

        # 2. Test generation (only if the model was found)
        if results["model_found"]:
            try:
                payload = {
                    "model": "llama3.1",
                    "prompt": "hi",
                    "stream": False
                }
                resp = requests.post(f"{base_url}/api/generate", json=payload, timeout=10)
                if resp.status_code == 200:
                    ans = resp.json().get('response', '')
                    results["details"].append(f"Generation Test Pass: '{ans[:20]}...'")
                    results["status"] = "success"
                else:
                    results["details"].append(f"Generation Failed: {resp.text}")
            except Exception as e:
                results["details"].append(f"Generation Error: {e}")
        else:
            results["details"].append("Skipping generation test as model not found.")
    except Exception as e:
        results["details"].append(f"Unexpected Error: {e}")
    return results
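
@app.get("/", response_class=HTMLResponse)  # path assumed for the dashboard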
async def root(request: Request):
    """Serve the home screen dashboard."""
    return templates.TemplateResponse("index.html", {"request": request})
async def view_logs(request: Request):
    """Serve the logs page."""
    return templates.TemplateResponse("logs.html", {"request": request})
async def get_logs():
    """Return the buffered recent log lines."""
    return {"logs": list(log_buffer)}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)