Spaces:
Running
Running
| import json | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| from datetime import datetime | |
| import sqlite3 | |
class MemoryStore:
    """Persistent memory store for agent context and user preferences.

    Data lives in a single SQLite file with three tables:

    * ``memories``    - long-term memories with a type, JSON metadata and an
      importance score.
    * ``preferences`` - key/value user preferences (values stored as JSON).
    * ``context``     - short-term, per-session context entries.

    Every public method opens its own connection and always closes it, even
    when a statement raises, so no connection is held between calls.
    """

    # Upper bound on distinct query keywords searched by get_relevant_memories.
    _MAX_KEYWORDS = 3

    def __init__(self, db_path: str = "data/memory.db"):
        """Initialize memory store with SQLite.

        Args:
            db_path: Path of the SQLite database file. Missing parent
                directories are created.
        """
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self.db_path = db_path
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the database file at ``self.db_path``."""
        return sqlite3.connect(self.db_path)

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in *term* (used with ``ESCAPE '\\'``)."""
        return (
            term.replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )

    def _init_db(self):
        """Initialize database schema (idempotent: CREATE TABLE IF NOT EXISTS)."""
        conn = self._connect()
        try:
            cursor = conn.cursor()

            # Create memories table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT NOT NULL,
                    memory_type TEXT,
                    metadata TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    importance INTEGER DEFAULT 5
                )
            ''')

            # Create user preferences table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS preferences (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Create context table for short-term memory
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS context (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT,
                    content TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            conn.commit()
        finally:
            conn.close()

    def add_memory(
        self,
        content: str,
        memory_type: str = 'general',
        metadata: Optional[Dict[str, Any]] = None,
        importance: int = 5
    ) -> int:
        """
        Add a memory to long-term storage

        Args:
            content: Memory content
            memory_type: Type of memory (general, task, preference, etc.)
            metadata: Additional metadata; stored as JSON, ``{}`` when omitted
                or empty
            importance: Importance score (1-10)

        Returns:
            Memory ID
        """
        metadata_json = json.dumps(metadata) if metadata else '{}'

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO memories (content, memory_type, metadata, importance)
                VALUES (?, ?, ?, ?)
            ''', (content, memory_type, metadata_json, importance))
            memory_id = cursor.lastrowid
            conn.commit()
        finally:
            conn.close()

        return memory_id

    def get_memories(
        self,
        memory_type: Optional[str] = None,
        limit: int = 10,
        min_importance: int = 0
    ) -> List[Dict[str, Any]]:
        """
        Retrieve memories

        Args:
            memory_type: Filter by memory type
            limit: Maximum number of memories to return
            min_importance: Minimum importance score

        Returns:
            List of memories, most important first, then newest first
        """
        query = '''
            SELECT id, content, memory_type, metadata, created_at, importance
            FROM memories
            WHERE importance >= ?
        '''
        params: List[Any] = [min_importance]

        if memory_type:
            query += ' AND memory_type = ?'
            params.append(memory_type)

        # created_at only has one-second resolution, so id breaks ties to keep
        # the ordering deterministic (newest insert first).
        query += ' ORDER BY importance DESC, created_at DESC, id DESC LIMIT ?'
        params.append(limit)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute(query, params)
            rows = cursor.fetchall()
        finally:
            conn.close()

        return [
            {
                'id': row[0],
                'content': row[1],
                'memory_type': row[2],
                # The metadata column is nullable; treat NULL/empty as {}.
                'metadata': json.loads(row[3]) if row[3] else {},
                'created_at': row[4],
                'importance': row[5],
            }
            for row in rows
        ]

    def get_relevant_memories(self, query: str, k: int = 5) -> str:
        """
        Get memories relevant to a query

        Simple keyword-based search (can be enhanced with embeddings): up to
        three distinct whitespace-separated keywords of the query are matched
        as literal substrings of the memory content.

        Args:
            query: Search query
            k: Number of memories to return

        Returns:
            Formatted string of relevant memories, one ``[type] content`` line
            per memory, or "No relevant memories found."
        """
        # De-duplicate keywords (order preserved) so repeated words do not use
        # up the keyword budget.
        keywords = list(dict.fromkeys(query.lower().split()))[:self._MAX_KEYWORDS]
        if not keywords:
            return "No relevant memories found."

        matches = []
        conn = self._connect()
        try:
            cursor = conn.cursor()
            for keyword in keywords:
                # Escape % and _ so they match literally, not as wildcards.
                pattern = f'%{self._escape_like(keyword)}%'
                cursor.execute('''
                    SELECT content, memory_type, importance
                    FROM memories
                    WHERE LOWER(content) LIKE ? ESCAPE '\\'
                    ORDER BY importance DESC
                    LIMIT ?
                ''', (pattern, k))
                matches.extend(cursor.fetchall())
        finally:
            conn.close()

        if not matches:
            return "No relevant memories found."

        # Drop duplicate contents, keeping first-seen order, then cap at k.
        unique_memories = list({row[0]: row for row in matches}.values())[:k]

        return "\n".join(
            f"[{mem_type}] {content}"
            for content, mem_type, _ in unique_memories
        )

    def set_preference(self, key: str, value: Any):
        """Set a user preference; *value* must be JSON-serializable."""
        value_json = json.dumps(value)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT OR REPLACE INTO preferences (key, value, updated_at)
                VALUES (?, ?, CURRENT_TIMESTAMP)
            ''', (key, value_json))
            conn.commit()
        finally:
            conn.close()

    def get_preference(self, key: str, default: Any = None) -> Any:
        """Get a user preference, or *default* when the key is not stored."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT value FROM preferences WHERE key = ?', (key,))
            row = cursor.fetchone()
        finally:
            conn.close()

        if row:
            return json.loads(row[0])
        return default

    def get_all_preferences(self) -> Dict[str, Any]:
        """Get all user preferences as a ``{key: decoded value}`` dict."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('SELECT key, value FROM preferences')
            rows = cursor.fetchall()
        finally:
            conn.close()

        return {key: json.loads(value) for key, value in rows}

    def add_context(self, session_id: str, content: str):
        """Add to short-term context"""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT INTO context (session_id, content)
                VALUES (?, ?)
            ''', (session_id, content))
            conn.commit()
        finally:
            conn.close()

    def get_context(self, session_id: str, limit: int = 10) -> List[str]:
        """Get recent context for a session.

        Args:
            session_id: Session whose context entries are returned
            limit: Maximum number of entries to return

        Returns:
            The most recent entries, oldest first
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            # id breaks ties between rows written within the same second;
            # created_at alone leaves their relative order undefined.
            cursor.execute('''
                SELECT content FROM context
                WHERE session_id = ?
                ORDER BY created_at DESC, id DESC
                LIMIT ?
            ''', (session_id, limit))
            rows = cursor.fetchall()
        finally:
            conn.close()

        # Rows were fetched newest-first; return them in chronological order.
        return [row[0] for row in reversed(rows)]

    def clear_old_context(self, days: int = 7):
        """Clear context older than specified days.

        Args:
            days: Age threshold in days (expected to be non-negative)
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute('''
                DELETE FROM context
                WHERE created_at < datetime('now', ?)
            ''', (f'-{days} days',))
            conn.commit()
        finally:
            conn.close()