File size: 5,960 Bytes
4851501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
"""
Session Store Service
Thread-safe session-scoped storage for user layers and context.
Replaces global SESSION_LAYERS with per-session isolation.
"""
import logging
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)


class SessionStore:
    """
    Thread-safe, process-wide singleton holding per-session layer lists.

    Each session record stores:
    - layers: list of layer dicts added by the user (oldest first)
    - created: time the session record was first created
    - accessed: time of the most recent read or write

    Sessions idle for longer than the TTL (default 2 hours) are removed by
    ``cleanup_expired``. Every public method holds ``self._lock`` while it
    touches session data.
    """

    _instance = None
    # Guards singleton creation and one-time initialization so concurrent
    # first constructions cannot build two stores or reset ``_sessions``.
    _instance_lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are accepted here and consumed by __init__.
        # Without *args/**kwargs, SessionStore(ttl_hours=...) raised TypeError.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    instance = super(SessionStore, cls).__new__(cls)
                    instance.initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, ttl_hours: int = 2, max_layers_per_session: int = 15):
        """
        Configure the store on its first construction only.

        Args:
            ttl_hours: Idle time (hours) after which a session is expired.
            max_layers_per_session: Maximum layers kept per session; the
                oldest layers are evicted first.

        Later constructions return the same instance and ignore arguments.
        """
        with self._instance_lock:
            if self.initialized:
                return
            self._sessions: Dict[str, dict] = {}
            self._lock = threading.Lock()
            self.ttl = timedelta(hours=ttl_hours)
            self.max_layers = max_layers_per_session
            self.initialized = True
        logger.info(
            "SessionStore initialized with TTL=%sh, max_layers=%s",
            ttl_hours,
            max_layers_per_session,
        )

    def _get_or_create_session(self, session_id: str) -> dict:
        """
        Return the session record for ``session_id``, creating it if absent.

        The caller must already hold ``self._lock``.
        """
        session = self._sessions.get(session_id)
        if session is None:
            now = datetime.now()
            session = {"layers": [], "created": now, "accessed": now}
            self._sessions[session_id] = session
        return session

    def get_layers(self, session_id: str) -> List[dict]:
        """
        Return copies of all layers for a session, oldest first.

        Creates an empty session if none exists and refreshes its access time.
        Each layer dict is copied, consistent with ``get_layer_by_index``, so
        callers cannot mutate stored layers outside the lock. Use
        ``update_layer`` to change a stored layer.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            session["accessed"] = datetime.now()
            return [layer.copy() for layer in session["layers"]]

    def add_layer(self, session_id: str, layer: dict) -> None:
        """
        Append a layer to a session, creating the session if needed.

        When the session exceeds ``max_layers``, the oldest layers are evicted.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            layers = session["layers"]
            layers.append(layer)
            session["accessed"] = datetime.now()
            # Evict by slice rather than `while len > max: pop(0)`. The old
            # loop raised IndexError once the list emptied under a negative cap.
            excess = len(layers) - self.max_layers
            if excess > 0:
                for removed in layers[:excess]:
                    logger.debug(
                        "Session %s: removed oldest layer %s",
                        session_id[:8],
                        removed.get('name', 'unknown'),
                    )
                del layers[:excess]

    def update_layer(self, session_id: str, layer_id: str, updates: dict) -> bool:
        """
        Merge ``updates`` into the first layer whose "id" equals ``layer_id``.

        Returns True if such a layer was found and updated. An unknown session
        returns False and is not created.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False
            for layer in session["layers"]:
                if layer.get("id") == layer_id:
                    layer.update(updates)
                    session["accessed"] = datetime.now()
                    return True
            return False

    def remove_layer(self, session_id: str, layer_id: str) -> bool:
        """
        Remove every layer whose "id" equals ``layer_id`` from a session.

        Returns True if at least one layer was removed.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False
            original_len = len(session["layers"])
            session["layers"] = [l for l in session["layers"] if l.get("id") != layer_id]
            session["accessed"] = datetime.now()
            return len(session["layers"]) < original_len

    def clear_session(self, session_id: str) -> None:
        """Drop all data for a session; unknown ids are ignored."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def get_layer_by_index(self, session_id: str, index: int) -> Optional[dict]:
        """
        Return a copy of the layer at 1-based ``index`` (e.g. "Layer 1").

        Returns None for an unknown session or an out-of-range index.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return None
            layers = session["layers"]
            if 1 <= index <= len(layers):
                return layers[index - 1].copy()
            return None

    def cleanup_expired(self) -> int:
        """
        Remove sessions whose last access is older than the TTL.

        Returns the number of sessions removed.
        """
        with self._lock:
            now = datetime.now()
            expired = [
                sid for sid, data in self._sessions.items()
                if now - data.get("accessed", data["created"]) > self.ttl
            ]
            for sid in expired:
                del self._sessions[sid]
        if expired:
            logger.info("Cleaned up %s expired sessions.", len(expired))
        return len(expired)

    def get_stats(self) -> dict:
        """Return counts of active sessions and layers plus the configuration."""
        with self._lock:
            total_layers = sum(len(s["layers"]) for s in self._sessions.values())
            return {
                "active_sessions": len(self._sessions),
                "total_layers": total_layers,
                "ttl_hours": self.ttl.total_seconds() / 3600,
                "max_layers_per_session": self.max_layers,
            }
# Singleton accessor
_session_store: Optional[SessionStore] = None
# Serializes first-time creation in get_session_store so concurrent first
# callers cannot each construct (and return) a different store.
_session_store_lock = threading.Lock()


def get_session_store() -> SessionStore:
    """
    Return the process-wide SessionStore, creating it on first use.

    Creation is guarded by a module-level lock with a re-check inside it, so
    every caller receives the same instance.
    """
    global _session_store
    if _session_store is None:
        with _session_store_lock:
            if _session_store is None:
                _session_store = SessionStore()
    return _session_store
|