| """ | |
| Advanced Local Database Manager for Stock Alchemist | |
| Stores data as JSON files organized by date, type, and ticker | |
| Uses MySQL for indexing and JSON files for data storage | |
| """ | |
| import json | |
| import mysql.connector | |
| from mysql.connector import Error | |
| from datetime import date, datetime, timedelta | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union | |
| from dataclasses import dataclass, asdict, field | |
| from enum import Enum | |
| import hashlib | |
| import gzip | |
| import numpy as np | |
class DataType(Enum):
    """Supported data types - simplified naming"""
    # Calendar events (no prefix needed)
    EARNINGS = "earnings"
    IPO = "ipo"
    STOCK_SPLIT = "stock_split"
    DIVIDENDS = "dividends"
    ECONOMIC_EVENTS = "economic_events"
    # Other data types
    FUNDAMENTAL = "fundamental_analysis"
    NEWS = "news"
    TECHNICAL_ANALYSIS = "technical_analysis"
@dataclass
class DatabaseEntry:
    """Base class for database entries"""
    date: str  # ISO format YYYY-MM-DD
    data_type: str  # DataType enum value
    ticker: str
    data: Dict[str, Any]
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    expiry_date: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self):
        """Convert to dictionary"""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict):
        """Create from dictionary"""
        return cls(**data)
    def generate_key(self):
        """
        Generate unique key for this entry.
        For calendar events, includes execution_date/ex_date to prevent duplicates.
        """
        # For calendar events, include the actual event date to ensure uniqueness
        if self.data_type in ['earnings', 'ipo', 'stock_split', 'dividends']:
            event_date = (self.data.get('execution_date') or
                          self.data.get('ex_date') or
                          self.data.get('date') or
                          self.date)
            key_string = f"{self.data_type}_{self.ticker}_{event_date}"
        else:
            key_string = f"{self.date}_{self.data_type}_{self.ticker}"
        return hashlib.md5(key_string.encode()).hexdigest()
class LocalDatabase:
    """
    Advanced local database manager backed by MySQL

    Features:
    - MySQL tables indexed by date, type, and ticker for fast queries
    - Entry data stored as JSON columns (JSON file helpers also available)
    - Optional gzip compression for file-based storage
    - Automatic expiry and cleanup
    - Batch operations support
    """
    def __init__(self, db_dir: str = "database", compress: bool = False):
        """
        Initialize database manager

        Args:
            db_dir: Root directory for database storage
            compress: Whether to compress JSON files with gzip
        """
        self.db_dir = Path(db_dir)
        try:
            self.db_dir.mkdir(exist_ok=True)
            self.data_dir = self.db_dir / "data"
            self.data_dir.mkdir(exist_ok=True)
        except (PermissionError, OSError) as e:
            print(f"⚠️ Error creating database dir at {self.db_dir}: {e}")
            import tempfile
            self.db_dir = Path(tempfile.gettempdir()) / "gotti_database"
            print(f"⚠️ Falling back to temporary directory: {self.db_dir}")
            self.db_dir.mkdir(exist_ok=True)
            self.data_dir = self.db_dir / "data"
            self.data_dir.mkdir(exist_ok=True)

        # Load environment variables
        from dotenv import load_dotenv
        import os
        load_dotenv()

        # MySQL connection parameters from environment variables
        self.mysql_config = {
            'host': os.getenv('DB_HOST', 'localhost').strip(),
            'user': os.getenv('DB_USERNAME', 'root').strip(),
            'password': os.getenv('DB_PASSWORD', '').strip(),
            'database': os.getenv('DB_DATABASE', 'gotti').strip(),
            'port': int(os.getenv('DB_PORT', 3306))
        }
        # SSL configuration for TiDB
        ssl_ca = os.getenv('DB_SSL_CA')
        if ssl_ca:
            # DB_SSL_CA may contain the certificate content itself rather than a path
            if "-----BEGIN CERTIFICATE-----" in ssl_ca:
                try:
                    # Write the certificate to a persistent file for the life of the process.
                    # A fixed, known path keeps debugging simple; in a container the file is
                    # discarded with the filesystem.
                    tmp_ca_path = Path("/tmp/tidb_ca.pem")
                    # On Windows (local dev), write next to the app instead
                    if os.name == 'nt':
                        tmp_ca_path = Path("tidb_ca.pem")
                    with open(tmp_ca_path, "w", encoding='utf-8') as f:
                        f.write(ssl_ca)
                    self.mysql_config['ssl_ca'] = str(tmp_ca_path)
                    self.mysql_config['ssl_verify_cert'] = True
                    self.mysql_config['ssl_verify_identity'] = True
                    print(f"🔒 Using provided SSL certificate content (saved to {tmp_ca_path})")
                except Exception as e:
                    print(f"⚠️ Failed to write SSL CA content: {e}")
            else:
                # Otherwise treat DB_SSL_CA as a file path; resolve relative paths
                if not os.path.isabs(ssl_ca):
                    project_root = Path(__file__).parent.parent.parent
                    ssl_ca_path = project_root / ssl_ca
                    if ssl_ca_path.exists():
                        self.mysql_config['ssl_ca'] = str(ssl_ca_path)
                        self.mysql_config['ssl_verify_cert'] = True
                        self.mysql_config['ssl_verify_identity'] = True
                        print(f"🔒 Using SSL certificate from file: {ssl_ca_path}")
                    else:
                        print(f"⚠️ SSL CA file not found at {ssl_ca_path}")
                elif os.path.exists(ssl_ca):
                    self.mysql_config['ssl_ca'] = ssl_ca
                    self.mysql_config['ssl_verify_cert'] = True
                    self.mysql_config['ssl_verify_identity'] = True

        self.compress = compress
        self._init_database()
    def _create_connection(self):
        """Create and return a MySQL database connection"""
        try:
            connection = mysql.connector.connect(**self.mysql_config)
            return connection
        except Error as e:
            print(f"❌ Error connecting to MySQL: {e}")
            return None

    def _get_table_name(self, data_type: str) -> str:
        """Determine which table to use based on data_type"""
        # Calendar events
        if data_type in ['earnings', 'ipo', 'stock_split', 'dividends', 'economic_events']:
            return 'calendar'
        # News
        elif data_type == 'news':
            return 'news'
        # Fundamental analysis
        elif data_type == 'fundamental_analysis':
            return 'fundamental_analysis'
        else:
            raise ValueError(f"Unknown data type: {data_type}")
    def _init_database(self):
        """Initialize MySQL tables - three separate tables by data category"""
        conn = self._create_connection()
        if not conn:
            raise Exception("Failed to connect to MySQL database")
        cursor = conn.cursor()
        try:
            # Create calendar table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS calendar (
                    entry_key VARCHAR(32) PRIMARY KEY,
                    date DATE NOT NULL,
                    event_type VARCHAR(50) NOT NULL,
                    ticker VARCHAR(20) NOT NULL,
                    data JSON NOT NULL,
                    created_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL,
                    expiry_date DATE,
                    metadata JSON,
                    execution_date DATE,
                    INDEX idx_date (date),
                    INDEX idx_event_type (event_type),
                    INDEX idx_ticker (ticker),
                    INDEX idx_date_event (date, event_type),
                    INDEX idx_ticker_event (ticker, event_type)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')

            # Create news table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS news (
                    entry_key VARCHAR(32) PRIMARY KEY,
                    date DATE NOT NULL,
                    ticker VARCHAR(20) NOT NULL,
                    data JSON NOT NULL,
                    created_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL,
                    expiry_date DATE,
                    metadata JSON,
                    INDEX idx_date (date),
                    INDEX idx_ticker (ticker),
                    INDEX idx_date_ticker (date, ticker)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')

            # Create fundamental_analysis table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS fundamental_analysis (
                    entry_key VARCHAR(32) PRIMARY KEY,
                    date DATE NOT NULL,
                    ticker VARCHAR(20) NOT NULL,
                    data JSON NOT NULL,
                    created_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL,
                    expiry_date DATE,
                    metadata JSON,
                    INDEX idx_date (date),
                    INDEX idx_ticker (ticker),
                    INDEX idx_date_ticker (date, ticker)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')

            # Create available_tickers table - whitelist of allowed tickers
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS available_tickers (
                    ticker VARCHAR(20) PRIMARY KEY,
                    name VARCHAR(255),
                    exchange VARCHAR(50),
                    sector VARCHAR(100),
                    is_active BOOLEAN DEFAULT TRUE,
                    added_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL,
                    metadata JSON,
                    INDEX idx_is_active (is_active),
                    INDEX idx_exchange (exchange),
                    INDEX idx_sector (sector)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')

            # Create signals table - tracks actionable ticker signals
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS signals (
                    signal_id VARCHAR(32) PRIMARY KEY,
                    ticker VARCHAR(20) NOT NULL,
                    signal_date DATE NOT NULL,
                    signal_position VARCHAR(10) NOT NULL,
                    calendar_event_keys JSON,
                    news_keys JSON,
                    fundamental_analysis_key VARCHAR(32),
                    sentiment JSON,
                    created_at DATETIME NOT NULL,
                    updated_at DATETIME NOT NULL,
                    metadata JSON,
                    INDEX idx_ticker (ticker),
                    INDEX idx_signal_date (signal_date),
                    INDEX idx_ticker_date (ticker, signal_date),
                    INDEX idx_created_at (created_at)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            ''')
            # Add execution_date column to calendar table if it doesn't exist
            try:
                cursor.execute("ALTER TABLE calendar ADD COLUMN execution_date DATE AFTER date")
                print("✅ Added execution_date column to calendar table")
            except Error as e:
                if e.errno != 1060:  # Error 1060 = duplicate column name, i.e. column already exists
                    print(f"⚠️ Could not add execution_date column: {e}")

            # Add sentiment column to signals table if it doesn't exist
            try:
                cursor.execute("ALTER TABLE signals ADD COLUMN sentiment JSON AFTER fundamental_analysis_key")
                print("✅ Added sentiment column to signals table")
            except Error as e:
                if e.errno != 1060:  # Error 1060 = duplicate column name, i.e. column already exists
                    print(f"⚠️ Could not add sentiment column: {e}")

            # Add signal_position column to signals table if it doesn't exist
            try:
                cursor.execute("ALTER TABLE signals ADD COLUMN signal_position VARCHAR(10) AFTER signal_date")
                print("✅ Added signal_position column to signals table")
            except Error as e:
                if e.errno != 1060:  # Error 1060 = duplicate column name, i.e. column already exists
                    print(f"⚠️ Could not add signal_position column: {e}")

            conn.commit()
            print("✅ MySQL database tables initialized successfully (calendar, news, fundamental_analysis, available_tickers, signals)")
        except Error as e:
            print(f"❌ Error initializing database: {e}")
            # Fallback migrations: check for the columns explicitly when the error-based ALTERs above fail
            try:
                cursor.execute("SHOW COLUMNS FROM calendar LIKE 'execution_date'")
                if cursor.fetchone() is None:
                    cursor.execute("ALTER TABLE calendar ADD COLUMN execution_date DATE AFTER date")
                    conn.commit()
                    print("✅ Added execution_date column to calendar table")
            except Exception as alter_error:
                print(f"⚠️ Could not add execution_date column: {alter_error}")

            # Try adding sentiment column for older MySQL versions
            try:
                cursor.execute("SHOW COLUMNS FROM signals LIKE 'sentiment'")
                if cursor.fetchone() is None:
                    cursor.execute("ALTER TABLE signals ADD COLUMN sentiment JSON AFTER fundamental_analysis_key")
                    conn.commit()
                    print("✅ Added sentiment column to signals table")
            except Exception as alter_error:
                print(f"⚠️ Could not add sentiment column: {alter_error}")

            # Try adding signal_position column for older MySQL versions
            try:
                cursor.execute("SHOW COLUMNS FROM signals LIKE 'signal_position'")
                if cursor.fetchone() is None:
                    cursor.execute("ALTER TABLE signals ADD COLUMN signal_position VARCHAR(10) AFTER signal_date")
                    conn.commit()
                    print("✅ Added signal_position column to signals table")
            except Exception as alter_error:
                print(f"⚠️ Could not add signal_position column: {alter_error}")
        finally:
            cursor.close()
            conn.close()
    def _generate_file_path(self, entry_key: str, data_type: str, date_str: str) -> Path:
        """Generate organized file path for data storage"""
        # Organize by type/year/month/
        year_month = datetime.fromisoformat(date_str).strftime("%Y/%m")
        type_dir = self.data_dir / data_type / year_month
        type_dir.mkdir(parents=True, exist_ok=True)
        extension = ".json.gz" if self.compress else ".json"
        return type_dir / f"{entry_key}{extension}"

    def _write_json(self, file_path: Path, data: Dict):
        """Write JSON data with optional compression"""
        json_str = json.dumps(data, indent=2, default=str)
        if self.compress:
            with gzip.open(file_path, 'wt', encoding='utf-8') as f:
                f.write(json_str)
        else:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(json_str)

    def _read_json(self, file_path: Path, compressed: bool) -> Dict:
        """Read JSON data with optional decompression"""
        if compressed:
            with gzip.open(file_path, 'rt', encoding='utf-8') as f:
                return json.load(f)
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
    def _clean_data_for_json(self, data: Any) -> Any:
        """
        Recursively clean data to ensure it's JSON serializable.
        - Converts NaN/Infinity to None
        - Converts numpy types to native Python types
        """
        if isinstance(data, dict):
            return {k: self._clean_data_for_json(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [self._clean_data_for_json(v) for v in data]
        elif isinstance(data, float):
            if np.isnan(data) or np.isinf(data):
                return None
            return float(data)
        elif isinstance(data, np.integer):
            return int(data)
        elif isinstance(data, np.floating):
            if np.isnan(data) or np.isinf(data):
                return None
            return float(data)
        elif isinstance(data, np.ndarray):
            return self._clean_data_for_json(data.tolist())
        return data
    def add_ticker(self, ticker: str, name: Optional[str] = None, exchange: Optional[str] = None,
                   sector: Optional[str] = None, metadata: Optional[Dict] = None) -> bool:
        """
        Add a ticker to the available_tickers whitelist.

        Args:
            ticker: Ticker symbol
            name: Company name
            exchange: Exchange name (e.g., 'NASDAQ', 'NYSE')
            sector: Company sector
            metadata: Additional metadata as JSON

        Returns:
            True if successful, False otherwise
        """
        conn = self._create_connection()
        if not conn:
            return False
        cursor = conn.cursor()
        try:
            now = datetime.now()
            cursor.execute('''
                INSERT INTO available_tickers
                (ticker, name, exchange, sector, is_active, added_at, updated_at, metadata)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    name = VALUES(name),
                    exchange = VALUES(exchange),
                    sector = VALUES(sector),
                    is_active = VALUES(is_active),
                    updated_at = VALUES(updated_at),
                    metadata = VALUES(metadata)
            ''', (
                ticker.upper(),
                name,
                exchange,
                sector,
                True,
                now,
                now,
                json.dumps(metadata) if metadata else None
            ))
            conn.commit()
            return True
        except Error as e:
            print(f"❌ Error adding ticker {ticker}: {e}")
            return False
        finally:
            cursor.close()
            conn.close()
    def get_macroeconomic_indicators(self) -> Dict[str, Any]:
        """
        Retrieve macroeconomic indicators from the database.

        Returns:
            Dictionary of macroeconomic indicators
        """
        conn = self._create_connection()
        if not conn:
            return {}
        cursor = conn.cursor()
        try:
            cursor.execute('''
                SELECT data
                FROM macroeconomic_indicators
                ORDER BY date DESC
                LIMIT 1
            ''')
            row = cursor.fetchone()
            if row:
                return json.loads(row[0])
            return {}
        except Error as e:
            print(f"❌ Error fetching macroeconomic indicators: {e}")
            return {}
        finally:
            cursor.close()
            conn.close()
    def remove_ticker(self, ticker: str) -> bool:
        """
        Deactivate a ticker (soft delete - sets is_active to False).

        Args:
            ticker: Ticker symbol to deactivate

        Returns:
            True if successful, False otherwise
        """
        conn = self._create_connection()
        if not conn:
            return False
        cursor = conn.cursor()
        try:
            cursor.execute(
                "UPDATE available_tickers SET is_active = FALSE, updated_at = %s WHERE ticker = %s",
                (datetime.now(), ticker.upper())
            )
            conn.commit()
            return cursor.rowcount > 0
        except Error as e:
            print(f"❌ Error removing ticker {ticker}: {e}")
            return False
        finally:
            cursor.close()
            conn.close()
    def get_all_available_tickers(self) -> List[str]:
        """
        Get all active tickers from the whitelist.

        Returns:
            List of ticker symbols
        """
        conn = self._create_connection()
        if not conn:
            return []
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT ticker FROM available_tickers WHERE is_active = TRUE ORDER BY ticker")
            return [row[0] for row in cursor.fetchall()]
        except Error as e:
            print(f"❌ Error fetching available tickers: {e}")
            return []
        finally:
            cursor.close()
            conn.close()
    def is_ticker_available(self, ticker: str) -> bool:
        """
        Check if ticker is in the available_tickers whitelist.

        Args:
            ticker: Ticker symbol to check

        Returns:
            True if ticker is available and active, False otherwise
        """
        conn = self._create_connection()
        if not conn:
            return False
        cursor = conn.cursor()
        try:
            cursor.execute(
                "SELECT is_active FROM available_tickers WHERE ticker = %s",
                (ticker.upper(),)
            )
            result = cursor.fetchone()
            if result and result[0]:  # Ticker exists and is_active = True
                return True
            return False
        except Error as e:
            print(f"❌ Error checking ticker availability for {ticker}: {e}")
            return False
        finally:
            cursor.close()
            conn.close()
    def save(self, entry: DatabaseEntry, expiry_days: Optional[int] = None) -> bool:
        """
        Save entry to database. Updates existing entry if duplicate is found.
        IMPORTANT: Checks if ticker is in available_tickers whitelist before saving.

        Args:
            entry: DatabaseEntry to save
            expiry_days: Optional expiry in days

        Returns:
            True if successful, False if ticker not available or save fails
        """
        try:
            # CRITICAL: Check if ticker is in the available_tickers whitelist
            # Skip check for economic events as they use country names as tickers
            if entry.data_type != DataType.ECONOMIC_EVENTS.value and not self.is_ticker_available(entry.ticker):
                print(f"⚠️ Skipping {entry.data_type} for {entry.ticker} - ticker not in available_tickers whitelist")
                return False

            entry_key = entry.generate_key()

            # Get the appropriate table name
            table_name = self._get_table_name(entry.data_type)

            # Check if entry already exists
            conn = self._create_connection()
            if not conn:
                return False
            cursor = conn.cursor()
            cursor.execute(f'SELECT created_at FROM {table_name} WHERE entry_key = %s', (entry_key,))
            existing = cursor.fetchone()

            # Preserve original created_at if updating
            if existing:
                entry.created_at = str(existing[0])

            # Update the updated_at timestamp
            entry.updated_at = datetime.now().isoformat()

            # Calculate expiry date if specified
            if expiry_days:
                expiry_date = (datetime.now() + timedelta(days=expiry_days)).date().isoformat()
                entry.expiry_date = expiry_date

            # Store data directly as JSON in the database
            # Different INSERT statement based on table structure
            if table_name == 'calendar':
                # Calendar table has an event_type column
                cursor.execute('''
                    INSERT INTO calendar
                    (entry_key, date, event_type, ticker, data, created_at, updated_at,
                     expiry_date, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        data = VALUES(data),
                        updated_at = VALUES(updated_at),
                        expiry_date = VALUES(expiry_date),
                        metadata = VALUES(metadata)
                ''', (
                    entry_key,
                    entry.date,
                    entry.data_type,  # event_type (earnings, ipo, etc.)
                    entry.ticker,
                    json.dumps(self._clean_data_for_json(entry.data), default=str),
                    entry.created_at,
                    entry.updated_at,
                    entry.expiry_date,
                    json.dumps(entry.metadata)
                ))
            else:
                # News and fundamental_analysis tables don't have event_type
                cursor.execute(f'''
                    INSERT INTO {table_name}
                    (entry_key, date, ticker, data, created_at, updated_at,
                     expiry_date, metadata)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        data = VALUES(data),
                        updated_at = VALUES(updated_at),
                        expiry_date = VALUES(expiry_date),
                        metadata = VALUES(metadata)
                ''', (
                    entry_key,
                    entry.date,
                    entry.ticker,
                    json.dumps(self._clean_data_for_json(entry.data), default=str),
                    entry.created_at,
                    entry.updated_at,
                    entry.expiry_date,
                    json.dumps(entry.metadata)
                ))

            conn.commit()
            return True
        except Exception as e:
            print(f"Error saving entry {entry.ticker}: {e}")
            return False
        finally:
            if 'cursor' in locals():
                cursor.close()
            if 'conn' in locals() and conn:
                conn.close()
    def save_batch(self, entries: List[DatabaseEntry], expiry_days: Optional[int] = None) -> int:
        """
        Save multiple entries in batch. Updates existing entries if duplicates are found.

        Args:
            entries: List of DatabaseEntry objects
            expiry_days: Optional expiry in days

        Returns:
            Number of successfully saved entries
        """
        success_count = 0
        conn = self._create_connection()
        if not conn:
            return 0
        cursor = conn.cursor()
        try:
            for entry in entries:
                try:
                    # CRITICAL: Check if ticker is in the available_tickers whitelist
                    # Skip check for economic events as they use country names as tickers
                    if entry.data_type != DataType.ECONOMIC_EVENTS.value and not self.is_ticker_available(entry.ticker):
                        print(f"⚠️ Skipping {entry.data_type} for {entry.ticker} - ticker not in available_tickers whitelist")
                        continue

                    entry_key = entry.generate_key()
                    table_name = self._get_table_name(entry.data_type)

                    # Check if entry already exists
                    cursor.execute(f'SELECT created_at FROM {table_name} WHERE entry_key = %s', (entry_key,))
                    existing = cursor.fetchone()

                    # Preserve original created_at if updating
                    if existing:
                        entry.created_at = str(existing[0])

                    # Update the updated_at timestamp
                    entry.updated_at = datetime.now().isoformat()

                    if expiry_days:
                        expiry_date = (datetime.now() + timedelta(days=expiry_days)).date().isoformat()
                        entry.expiry_date = expiry_date

                    # Store data directly in database - different format for calendar
                    if table_name == 'calendar':
                        cursor.execute('''
                            INSERT INTO calendar
                            (entry_key, date, event_type, ticker, data, created_at, updated_at,
                             expiry_date, metadata)
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                            ON DUPLICATE KEY UPDATE
                                data = VALUES(data),
                                updated_at = VALUES(updated_at),
                                expiry_date = VALUES(expiry_date),
                                metadata = VALUES(metadata)
                        ''', (
                            entry_key,
                            entry.date,
                            entry.data_type,  # event_type
                            entry.ticker,
                            json.dumps(self._clean_data_for_json(entry.data), default=str),
                            entry.created_at,
                            entry.updated_at,
                            entry.expiry_date,
                            json.dumps(entry.metadata)
                        ))
                    else:
                        cursor.execute(f'''
                            INSERT INTO {table_name}
                            (entry_key, date, ticker, data, created_at, updated_at,
                             expiry_date, metadata)
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                            ON DUPLICATE KEY UPDATE
                                data = VALUES(data),
                                updated_at = VALUES(updated_at),
                                expiry_date = VALUES(expiry_date),
                                metadata = VALUES(metadata)
                        ''', (
                            entry_key,
                            entry.date,
                            entry.ticker,
                            json.dumps(self._clean_data_for_json(entry.data), default=str),
                            entry.created_at,
                            entry.updated_at,
                            entry.expiry_date,
                            json.dumps(entry.metadata)
                        ))
                    success_count += 1
                except Exception as e:
                    print(f"Error saving entry {entry.ticker}: {e}")
                    continue
            conn.commit()
        finally:
            cursor.close()
            conn.close()
        return success_count
    def save_signal(self, ticker: str, calendar_event_keys: List[str], news_keys: List[str],
                    fundamental_key: str, signal_position: str, sentiment: Optional[Dict] = None) -> bool:
        """Save a generated signal to the database"""
        try:
            conn = self._create_connection()
            if not conn:
                return False
            cursor = conn.cursor()

            signal_date = datetime.now().date().isoformat()
            signal_id = hashlib.md5(f"{ticker}_{signal_date}".encode()).hexdigest()
            now = datetime.now()

            # Merge with existing sentiment if provided
            final_sentiment = sentiment
            if sentiment:
                cursor.execute("SELECT sentiment FROM signals WHERE signal_id = %s", (signal_id,))
                existing = cursor.fetchone()
                if existing and existing[0]:
                    existing_sentiment = json.loads(existing[0]) if isinstance(existing[0], str) else existing[0]
                    if isinstance(existing_sentiment, dict):
                        existing_sentiment.update(sentiment)
                        final_sentiment = existing_sentiment

            cursor.execute('''
                INSERT INTO signals
                (signal_id, ticker, signal_date, signal_position, calendar_event_keys, news_keys,
                 fundamental_analysis_key, sentiment, created_at, updated_at, metadata)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                    signal_position = VALUES(signal_position),
                    calendar_event_keys = VALUES(calendar_event_keys),
                    news_keys = VALUES(news_keys),
                    fundamental_analysis_key = VALUES(fundamental_analysis_key),
                    sentiment = VALUES(sentiment),
                    updated_at = VALUES(updated_at)
            ''', (
                signal_id,
                ticker,
                signal_date,
                signal_position,
                json.dumps(calendar_event_keys),
                json.dumps(news_keys),
                fundamental_key,
                json.dumps(final_sentiment) if final_sentiment else None,
                now,
                now,
                json.dumps({})
            ))
            conn.commit()
            return True
        except Error as e:
            print(f"❌ Error saving signal: {e}")
            return False
        finally:
            if 'cursor' in locals():
                cursor.close()
            if 'conn' in locals() and conn:
                conn.close()
    def get_signal(self, ticker: str, date_str: Optional[str] = None) -> Optional[Dict]:
        """Get signal for a ticker on a specific date (defaults to today)"""
        try:
            if not date_str:
                date_str = datetime.now().date().isoformat()

            conn = self._create_connection()
            if not conn:
                return None
            cursor = conn.cursor(dictionary=True)
            cursor.execute('''
                SELECT * FROM signals
                WHERE ticker = %s AND signal_date = %s
            ''', (ticker, date_str))
            result = cursor.fetchone()

            if result:
                # Parse JSON fields
                for field_name in ['calendar_event_keys', 'news_keys', 'metadata', 'sentiment']:
                    if result.get(field_name):
                        if isinstance(result[field_name], str):
                            result[field_name] = json.loads(result[field_name])
                return result
            return None
        except Error as e:
            print(f"❌ Error getting signal: {e}")
            return None
        finally:
            if 'cursor' in locals():
                cursor.close()
            if 'conn' in locals() and conn:
                conn.close()
    def get_recent_signals(self, limit: int = 50) -> List[Dict]:
        """Get most recent signals"""
        try:
            conn = self._create_connection()
            if not conn:
                return []
            cursor = conn.cursor(dictionary=True)
            cursor.execute('''
                SELECT * FROM signals
                ORDER BY created_at DESC
                LIMIT %s
            ''', (limit,))
            results = cursor.fetchall()

            parsed_results = []
            for result in results:
                # Parse JSON fields
                for field_name in ['calendar_event_keys', 'news_keys', 'metadata', 'sentiment']:
                    if result.get(field_name) and isinstance(result[field_name], str):
                        try:
                            result[field_name] = json.loads(result[field_name])
                        except (json.JSONDecodeError, TypeError):
                            pass
                parsed_results.append(result)
            return parsed_results
        except Error as e:
            print(f"❌ Error getting recent signals: {e}")
            return []
        finally:
            if 'cursor' in locals():
                cursor.close()
            if 'conn' in locals() and conn:
                conn.close()
    def get(self, date_str: str, data_type: str, ticker: str) -> Optional[DatabaseEntry]:
        """
        Retrieve entry by date, type, and ticker

        Args:
            date_str: Date in YYYY-MM-DD format
            data_type: Data type (earnings, ipo, news, fundamental_analysis, etc.)
            ticker: Stock ticker

        Returns:
            DatabaseEntry if found, None otherwise
        """
        try:
            # Generate key
            key_string = f"{date_str}_{data_type}_{ticker}"
            entry_key = hashlib.md5(key_string.encode()).hexdigest()

            # Get table name for this data type
            table_name = self._get_table_name(data_type)

            conn = self._create_connection()
            if not conn:
                return None
            cursor = conn.cursor()

            # Different SELECT based on table structure
            if table_name == 'calendar':
                cursor.execute('''
                    SELECT date, event_type, ticker, data, created_at, updated_at, expiry_date, metadata
                    FROM calendar
                    WHERE entry_key = %s
                ''', (entry_key,))
            else:
                cursor.execute(f'''
                    SELECT date, ticker, data, created_at, updated_at, expiry_date, metadata
                    FROM {table_name}
                    WHERE entry_key = %s
                ''', (entry_key,))

            result = cursor.fetchone()
            cursor.close()
            conn.close()

            if not result:
                return None

            # Parse result based on table structure
            if table_name == 'calendar':
                date_val, event_type_val, ticker_val, data_json, created_at, updated_at, expiry_date, metadata_json = result
                data_type_val = event_type_val  # event_type is the data_type
            else:
                date_val, ticker_val, data_json, created_at, updated_at, expiry_date, metadata_json = result
                data_type_val = data_type  # Use the data_type parameter

            # Check if expired
            if expiry_date:
                if str(expiry_date) < datetime.now().date().isoformat():
                    return None

            # Parse JSON data from database
            data_dict = json.loads(data_json) if isinstance(data_json, str) else data_json
            metadata_dict = json.loads(metadata_json) if isinstance(metadata_json, str) else (metadata_json or {})

            # Create DatabaseEntry
            entry = DatabaseEntry(
                date=str(date_val),
                data_type=str(data_type_val),
                ticker=str(ticker_val),
                data=data_dict,
                created_at=str(created_at),
                updated_at=str(updated_at),
                expiry_date=str(expiry_date) if expiry_date else None,
                metadata=metadata_dict
            )
            return entry
        except Exception as e:
            print(f"Error retrieving entry: {e}")
            return None
    def query(self,
              date_from: Optional[str] = None,
              date_to: Optional[str] = None,
              data_type: Optional[str] = None,
              ticker: Optional[str] = None,
              limit: Optional[int] = None,
              include_expired: bool = False) -> List[DatabaseEntry]:
        """
        Query database with flexible filters across all tables

        Args:
            date_from: Start date (inclusive)
            date_to: End date (inclusive)
            data_type: Filter by data type (e.g., 'earnings', 'news', 'fundamental_analysis')
            ticker: Filter by ticker
            limit: Max results
            include_expired: Whether to include expired entries

        Returns:
            List of DatabaseEntry objects
        """
        try:
            conn = self._create_connection()
            if not conn:
                return []
            cursor = conn.cursor()
            entries = []

            # Determine which tables to query
            tables_to_query = []
            if data_type:
                # Query specific table based on data_type
                table_name = self._get_table_name(data_type)
                tables_to_query.append((table_name, data_type))
            else:
                # Query all tables
                tables_to_query = [
                    ('calendar', None),  # Will get all calendar events
                    ('news', 'news'),
                    ('fundamental_analysis', 'fundamental_analysis')
                ]

            # Query each table
            for table_name, specific_type in tables_to_query:
                params = []
                if table_name == 'calendar':
                    # Calendar has event_type column
                    query = "SELECT date, event_type, ticker, data, created_at, updated_at, expiry_date, metadata FROM calendar WHERE 1=1"
                    if specific_type:  # Specific calendar event type
                        query += " AND event_type = %s"
                        params.append(specific_type)
                else:
                    # News and fundamental_analysis don't have event_type
                    query = f"SELECT date, ticker, data, created_at, updated_at, expiry_date, metadata FROM {table_name} WHERE 1=1"

                if date_from:
                    query += " AND date >= %s"
                    params.append(date_from)
                if date_to:
                    query += " AND date <= %s"
                    params.append(date_to)
                if ticker:
                    query += " AND ticker = %s"
                    params.append(ticker)
                if not include_expired:
                    query += " AND (expiry_date IS NULL OR expiry_date >= %s)"
                    params.append(datetime.now().date().isoformat())

                query += " ORDER BY date DESC, created_at DESC"
                if limit and len(tables_to_query) == 1:
                    # Only apply limit if querying a single table
                    query += f" LIMIT {limit}"

                cursor.execute(query, tuple(params))
                results = cursor.fetchall()

                # Parse results based on table structure
                for row in results:
                    try:
                        if table_name == 'calendar':
                            date_val, event_type_val, ticker_val, data_json, created_at, updated_at, expiry_date, metadata_json = row
                            data_type_val = event_type_val
                        else:
                            date_val, ticker_val, data_json, created_at, updated_at, expiry_date, metadata_json = row
                            data_type_val = specific_type

                        # Parse JSON data
                        data_dict = json.loads(data_json) if isinstance(data_json, str) else data_json
                        metadata_dict = json.loads(metadata_json) if isinstance(metadata_json, str) else (metadata_json or {})

                        entry = DatabaseEntry(
                            date=str(date_val),
                            data_type=str(data_type_val),
                            ticker=str(ticker_val),
                            data=data_dict,
                            created_at=str(created_at),
                            updated_at=str(updated_at),
                            expiry_date=str(expiry_date) if expiry_date else None,
                            metadata=metadata_dict
                        )
                        entries.append(entry)
                    except Exception as e:
                        print(f"Error loading entry from {table_name}: {e}")
                        continue

            cursor.close()
            conn.close()

            # Sort all entries by date and apply limit if needed
            entries.sort(key=lambda x: (x.date, x.created_at), reverse=True)
            if limit:
                entries = entries[:limit]

            return entries
        except Exception as e:
            print(f"Error querying database: {e}")
            return []
    def delete(self, date_str: str, data_type: str, ticker: str) -> bool:
        """Delete entry by date, type, and ticker"""
        try:
            # Generate key
            key_string = f"{date_str}_{data_type}_{ticker}"
            entry_key = hashlib.md5(key_string.encode()).hexdigest()
            table_name = self._get_table_name(data_type)

            conn = self._create_connection()
            if not conn:
                return False
            cursor = conn.cursor()
            cursor.execute(f'DELETE FROM {table_name} WHERE entry_key = %s', (entry_key,))
            conn.commit()
            conn.close()
            return True
        except Exception as e:
            print(f"Error deleting entry: {e}")
            return False
    def clean_expired(self) -> int:
        """Remove expired entries"""
        try:
            conn = self._create_connection()
            if not conn:
                return 0
            cursor = conn.cursor()
            total_cleaned = 0

            for table_name in ['calendar', 'news', 'fundamental_analysis']:
                cursor.execute(f'''
                    DELETE FROM {table_name}
                    WHERE expiry_date IS NOT NULL AND expiry_date < %s
                ''', (datetime.now().date().isoformat(),))
                total_cleaned += cursor.rowcount

            conn.commit()
            conn.close()
            print(f"✓ Cleaned {total_cleaned} expired entries")
            return total_cleaned
        except Exception as e:
            print(f"Error cleaning expired entries: {e}")
            return 0
    def get_stats(self) -> Dict[str, Any]:
        """Get database statistics across all tables"""
        try:
            conn = self._create_connection()
            if not conn:
                return {}
            cursor = conn.cursor()

            # Initialize counters
            total_entries = 0
            by_type = {}
            all_tickers = {}
            total_size = 0
            expired_count = 0
            min_date = None
            max_date = None

            # Query each table
            for table_name in ['calendar', 'news', 'fundamental_analysis']:
                # Count entries
                cursor.execute(f'SELECT COUNT(*) FROM {table_name}')
                table_count = cursor.fetchone()[0]
                total_entries += table_count

                if table_name == 'calendar':
                    # Get counts by event_type
                    cursor.execute('SELECT event_type, COUNT(*) FROM calendar GROUP BY event_type')
                    for event_type, count in cursor.fetchall():
                        by_type[event_type] = count
                else:
                    # For news and fundamental_analysis, use table name as type
                    by_type[table_name] = table_count

                # Get ticker counts
                cursor.execute(f'SELECT ticker, COUNT(*) FROM {table_name} GROUP BY ticker')
                for ticker, count in cursor.fetchall():
                    all_tickers[ticker] = all_tickers.get(ticker, 0) + count

                # Get data size
                cursor.execute(f'SELECT SUM(LENGTH(data)) FROM {table_name}')
                table_size = cursor.fetchone()[0] or 0
                total_size += table_size

                # Count expired entries
                cursor.execute(f'''
                    SELECT COUNT(*) FROM {table_name}
                    WHERE expiry_date IS NOT NULL AND expiry_date < %s
                ''', (datetime.now().date().isoformat(),))
                expired_count += cursor.fetchone()[0]

                # Get date range
                cursor.execute(f'SELECT MIN(date), MAX(date) FROM {table_name}')
                table_date_range = cursor.fetchone()
                if table_date_range[0]:
                    if min_date is None or table_date_range[0] < min_date:
                        min_date = table_date_range[0]
                    if max_date is None or table_date_range[1] > max_date:
                        max_date = table_date_range[1]

            # Get top 10 tickers
            top_tickers = dict(sorted(all_tickers.items(), key=lambda x: x[1], reverse=True)[:10])

            conn.close()

            stats = {
                'total_entries': total_entries,
                'by_type': by_type,
                'top_tickers': top_tickers,
                'total_size_bytes': total_size,
                'total_size_mb': round(total_size / (1024 * 1024), 2),
                'expired_entries': expired_count,
                'date_range': {'from': str(min_date), 'to': str(max_date)} if min_date else None,
                'compression': 'enabled' if self.compress else 'disabled'
            }
            return stats
        except Exception as e:
            print(f"Error getting stats: {e}")
            return {}
    def clear_all(self) -> bool:
        """Clear all data (use with caution!)"""
        try:
            conn = self._create_connection()
            if not conn:
                return False
            cursor = conn.cursor()

            # Truncate all tables
            for table_name in ['calendar', 'news', 'fundamental_analysis', 'signals']:
                cursor.execute(f'TRUNCATE TABLE {table_name}')

            conn.commit()
            conn.close()
            print("✓ All data cleared")
            return True
        except Exception as e:
            print(f"Error clearing data: {e}")
            return False
# Global database instance
db_instance = None


def get_database(db_dir: str = "database", compress: bool = False) -> LocalDatabase:
    """Get or create database instance"""
    global db_instance
    if db_instance is None or db_instance.db_dir != Path(db_dir):
        db_instance = LocalDatabase(db_dir=db_dir, compress=compress)
    return db_instance
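

# --- Illustrative usage sketch (not part of the module API) ---
# A minimal example of how this module might be exercised, assuming a reachable
# MySQL/TiDB instance configured via the DB_* environment variables. The ticker,
# date, and payload below are hypothetical placeholders chosen for illustration.
if __name__ == "__main__":
    db = get_database()

    # Whitelist a ticker so subsequent saves are accepted
    db.add_ticker("AAPL", name="Apple Inc.", exchange="NASDAQ", sector="Technology")

    # Build and save a news entry; the key is derived from date, type, and ticker
    entry = DatabaseEntry(
        date="2024-01-15",
        data_type=DataType.NEWS.value,
        ticker="AAPL",
        data={"headline": "Example headline", "sentiment": 0.4},
    )
    if db.save(entry, expiry_days=30):
        # Retrieve it back by the same date/type/ticker triple
        print(db.get("2024-01-15", DataType.NEWS.value, "AAPL"))

    # Print aggregate statistics across all tables
    print(db.get_stats())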