""" Session Store Service Thread-safe session-scoped storage for user layers and context. Replaces global SESSION_LAYERS with per-session isolation. """ import logging import threading from datetime import datetime, timedelta from typing import Dict, List, Optional, Any logger = logging.getLogger(__name__) class SessionStore: """ Thread-safe session-scoped storage with TTL expiration. Each session maintains its own: - layers: Map layers created by the user - context: Optional conversation context Sessions expire after configurable TTL (default 2 hours). """ _instance = None def __new__(cls): if cls._instance is None: cls._instance = super(SessionStore, cls).__new__(cls) cls._instance.initialized = False return cls._instance def __init__(self, ttl_hours: int = 2, max_layers_per_session: int = 15): if self.initialized: return self._sessions: Dict[str, dict] = {} self._lock = threading.Lock() self.ttl = timedelta(hours=ttl_hours) self.max_layers = max_layers_per_session self.initialized = True logger.info(f"SessionStore initialized with TTL={ttl_hours}h, max_layers={max_layers_per_session}") def _get_or_create_session(self, session_id: str) -> dict: """Get existing session or create new one.""" if session_id not in self._sessions: self._sessions[session_id] = { "layers": [], "created": datetime.now(), "accessed": datetime.now() } return self._sessions[session_id] def get_layers(self, session_id: str) -> List[dict]: """Get all layers for a session.""" with self._lock: session = self._get_or_create_session(session_id) session["accessed"] = datetime.now() return session["layers"].copy() def add_layer(self, session_id: str, layer: dict) -> None: """ Add a layer to a session. Enforces max_layers limit by removing oldest layers. """ with self._lock: session = self._get_or_create_session(session_id) session["layers"].append(layer) session["accessed"] = datetime.now() # Enforce layer limit while len(session["layers"]) > self.max_layers: removed = session["layers"].pop(0) logger.debug(f"Session {session_id[:8]}: removed oldest layer {removed.get('name', 'unknown')}") def update_layer(self, session_id: str, layer_id: str, updates: dict) -> bool: """ Update an existing layer in a session. Returns True if layer was found and updated. """ with self._lock: session = self._sessions.get(session_id) if not session: return False for layer in session["layers"]: if layer.get("id") == layer_id: layer.update(updates) session["accessed"] = datetime.now() return True return False def remove_layer(self, session_id: str, layer_id: str) -> bool: """ Remove a layer from a session. Returns True if layer was found and removed. """ with self._lock: session = self._sessions.get(session_id) if not session: return False original_len = len(session["layers"]) session["layers"] = [l for l in session["layers"] if l.get("id") != layer_id] session["accessed"] = datetime.now() return len(session["layers"]) < original_len def clear_session(self, session_id: str) -> None: """Clear all data for a session.""" with self._lock: if session_id in self._sessions: del self._sessions[session_id] def get_layer_by_index(self, session_id: str, index: int) -> Optional[dict]: """Get a specific layer by 1-based index (for user references like 'Layer 1').""" with self._lock: session = self._sessions.get(session_id) if not session: return None layers = session["layers"] if 1 <= index <= len(layers): return layers[index - 1].copy() return None def cleanup_expired(self) -> int: """ Remove sessions older than TTL. Returns number of expired sessions removed. """ with self._lock: now = datetime.now() expired = [ sid for sid, data in self._sessions.items() if now - data.get("accessed", data["created"]) > self.ttl ] for sid in expired: del self._sessions[sid] if expired: logger.info(f"Cleaned up {len(expired)} expired sessions.") return len(expired) def get_stats(self) -> dict: """Return statistics about active sessions.""" with self._lock: total_layers = sum(len(s["layers"]) for s in self._sessions.values()) return { "active_sessions": len(self._sessions), "total_layers": total_layers, "ttl_hours": self.ttl.total_seconds() / 3600, "max_layers_per_session": self.max_layers } # Singleton accessor _session_store: Optional[SessionStore] = None def get_session_store() -> SessionStore: """Get the singleton session store instance.""" global _session_store if _session_store is None: _session_store = SessionStore() return _session_store