File size: 5,960 Bytes
4851501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Session Store Service

Thread-safe session-scoped storage for user layers and context.
Replaces global SESSION_LAYERS with per-session isolation.
"""

import logging
import threading
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any


logger = logging.getLogger(__name__)


class SessionStore:
    """
    Thread-safe session-scoped storage with TTL expiration.

    Each session maintains its own:
    - layers: Map layers created by the user (oldest evicted first once
      ``max_layers`` is exceeded)

    NOTE(review): a per-session conversation "context" was planned, but no
    such entry is stored by this class yet; only layers are kept.

    Sessions expire after a configurable TTL (default 2 hours), measured from
    their last access. Expired sessions are dropped only when
    ``cleanup_expired()`` is called.

    The class is a process-wide singleton: every ``SessionStore(...)`` call
    returns the same instance. Constructor arguments take effect only on the
    first construction.
    """

    _instance: Optional["SessionStore"] = None
    # Guards singleton creation and one-time initialization across threads.
    _instance_lock = threading.Lock()

    def __new__(cls, *args: Any, **kwargs: Any) -> "SessionStore":
        """Return the shared instance, creating it on first call.

        ``*args``/``**kwargs`` are accepted (and ignored here) so that
        ``SessionStore(ttl_hours=...)`` can reach ``__init__``. Without them,
        Python raises ``TypeError`` before ``__init__`` runs.
        """
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock: another thread may have won the race.
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance.initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self, ttl_hours: float = 2, max_layers_per_session: int = 15):
        """Initialize the shared store once; later calls are no-ops.

        Args:
            ttl_hours: Idle time after which ``cleanup_expired`` drops a session.
            max_layers_per_session: Maximum number of layers kept per session;
                ``add_layer`` evicts the oldest layers beyond this limit.
        """
        with self._instance_lock:
            # Python runs __init__ on every SessionStore() call. Only the first
            # call may set up state; re-running it would wipe existing sessions.
            if self.initialized:
                return

            self._sessions: Dict[str, dict] = {}
            self._lock = threading.Lock()
            self.ttl = timedelta(hours=ttl_hours)
            self.max_layers = max_layers_per_session
            self.initialized = True

        logger.info(
            "SessionStore initialized with TTL=%sh, max_layers=%s",
            ttl_hours,
            max_layers_per_session,
        )

    def _get_or_create_session(self, session_id: str) -> dict:
        """Return the session record for ``session_id``, creating it if missing.

        The caller must hold ``self._lock``.
        """
        session = self._sessions.get(session_id)
        if session is None:
            now = datetime.now()
            session = {"layers": [], "created": now, "accessed": now}
            self._sessions[session_id] = session
        return session

    def get_layers(self, session_id: str) -> List[dict]:
        """Return copies of all layers for a session, creating the session if missing.

        Each layer dict is shallow-copied (as ``get_layer_by_index`` does), so
        callers cannot mutate stored layers without holding the lock.
        Refreshes the session's last-access time.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            session["accessed"] = datetime.now()
            return [layer.copy() for layer in session["layers"]]

    def add_layer(self, session_id: str, layer: dict) -> None:
        """
        Append a layer to a session, creating the session if missing.

        The dict is stored as given (not copied). If the session now holds
        more than ``max_layers`` layers, the oldest ones are removed.
        """
        with self._lock:
            session = self._get_or_create_session(session_id)
            layers = session["layers"]
            layers.append(layer)
            session["accessed"] = datetime.now()

            # Enforce the layer limit by dropping the oldest entries in one step.
            # Computing the excess (rather than popping in a loop) also stays
            # safe when max_layers is zero or negative.
            excess = len(layers) - self.max_layers
            if excess > 0:
                evicted = layers[:excess]
                del layers[:excess]
                for removed in evicted:
                    logger.debug(
                        "Session %s: removed oldest layer %s",
                        session_id[:8],
                        removed.get("name", "unknown"),
                    )

    def update_layer(self, session_id: str, layer_id: str, updates: dict) -> bool:
        """
        Merge ``updates`` into the first layer whose ``"id"`` equals ``layer_id``.

        Does not create missing sessions. Returns True if the layer was found
        and updated.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False

            for layer in session["layers"]:
                if layer.get("id") == layer_id:
                    layer.update(updates)
                    session["accessed"] = datetime.now()
                    return True

            return False

    def remove_layer(self, session_id: str, layer_id: str) -> bool:
        """
        Remove every layer whose ``"id"`` equals ``layer_id`` from a session.

        Does not create missing sessions. Returns True if at least one layer
        was removed.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return False

            original_len = len(session["layers"])
            session["layers"] = [
                layer for layer in session["layers"] if layer.get("id") != layer_id
            ]
            session["accessed"] = datetime.now()

            return len(session["layers"]) < original_len

    def clear_session(self, session_id: str) -> None:
        """Delete all data for a session; a no-op if it does not exist."""
        with self._lock:
            self._sessions.pop(session_id, None)

    def get_layer_by_index(self, session_id: str, index: int) -> Optional[dict]:
        """Return a copy of a layer by 1-based index (for user references like 'Layer 1').

        Returns None if the session does not exist or the index is out of
        range. Does not create missing sessions. On a hit, refreshes the
        session's last-access time, consistent with ``get_layers``.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session:
                return None

            layers = session["layers"]
            if 1 <= index <= len(layers):
                session["accessed"] = datetime.now()
                return layers[index - 1].copy()

            return None

    def cleanup_expired(self) -> int:
        """
        Remove sessions whose last access is older than the TTL.

        Returns the number of expired sessions removed.
        """
        with self._lock:
            now = datetime.now()
            expired = [
                sid
                for sid, data in self._sessions.items()
                if now - data.get("accessed", data["created"]) > self.ttl
            ]

            for sid in expired:
                del self._sessions[sid]

            if expired:
                logger.info("Cleaned up %s expired sessions.", len(expired))

            return len(expired)

    def get_stats(self) -> dict:
        """Return counts of active sessions and stored layers, plus the configured limits."""
        with self._lock:
            total_layers = sum(len(s["layers"]) for s in self._sessions.values())

            return {
                "active_sessions": len(self._sessions),
                "total_layers": total_layers,
                "ttl_hours": self.ttl.total_seconds() / 3600,
                "max_layers_per_session": self.max_layers,
            }


# Singleton accessor
_session_store: Optional[SessionStore] = None
# Serializes first-time creation in get_session_store().
_session_store_lock = threading.Lock()


def get_session_store() -> SessionStore:
    """Return the process-wide session store, creating it on first use.

    Creation is guarded by a module-level lock, so concurrent first calls
    construct at most one store from this accessor.
    """
    global _session_store
    if _session_store is None:
        with _session_store_lock:
            # Re-check: another thread may have created it while we waited.
            if _session_store is None:
                _session_store = SessionStore()
    return _session_store