File size: 12,196 Bytes
7602502
 
 
 
 
 
 
 
 
 
 
 
 
 
3fac7d8
7602502
 
3fac7d8
7602502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fac7d8
 
7602502
 
 
 
 
 
 
3fac7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7602502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""Pinecone client for uploading and managing standard records."""

from __future__ import annotations

import time
from datetime import datetime, timezone
from pathlib import Path
from collections.abc import Callable
from typing import Any

from loguru import logger
from pinecone import Pinecone
from pinecone.exceptions import PineconeException

from src.mcp_config import get_mcp_settings
from tools.pinecone_models import PineconeRecord

settings = get_mcp_settings()


class PineconeClient:
    """Client for interacting with Pinecone index.

    Wraps index lifecycle checks, batched upserts with retry/backoff,
    semantic search with reranking, and local upload-marker bookkeeping.
    """

    def __init__(self) -> None:
        """Initialize Pinecone SDK from config settings.

        Raises:
            ValueError: If the PINECONE_API_KEY setting is empty/unset.
        """
        api_key = settings.pinecone_api_key
        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        self.pc = Pinecone(api_key=api_key)
        self.index_name = settings.pinecone_index_name
        self.namespace = settings.pinecone_namespace
        # Index handle is created lazily on first access via the `index` property.
        self._index = None

    @property
    def index(self) -> Any:
        """Get the index object, creating it on first access (lazy init)."""
        if self._index is None:
            self._index = self.pc.Index(self.index_name)
        return self._index

    def validate_index(self) -> None:
        """
        Check index exists with pc.has_index(), raise helpful error if not.

        Raises:
            ValueError: If index does not exist, with instructions to create it.
        """
        if not self.pc.has_index(name=self.index_name):
            raise ValueError(
                f"Index '{self.index_name}' not found. Create it with:\n"
                f"pc index create -n {self.index_name} -m cosine -c aws -r us-east-1 "
                f"--model llama-text-embed-v2 --field_map text=content"
            )

    def ensure_index_exists(self) -> bool:
        """
        Check if index exists, create it if not.

        Creates the index with integrated embeddings using llama-text-embed-v2 model.

        Returns:
            True if index was created, False if it already existed.
        """
        if self.pc.has_index(name=self.index_name):
            logger.info(f"Index '{self.index_name}' already exists")
            return False

        logger.info(f"Creating index '{self.index_name}' with integrated embeddings...")
        self.pc.create_index_for_model(
            name=self.index_name,
            cloud="aws",
            region="us-east-1",
            embed={
                "model": "llama-text-embed-v2",
                "field_map": {"text": "content"},
            },
        )
        logger.info(f"Successfully created index '{self.index_name}'")
        return True

    def get_index_stats(self) -> dict[str, Any]:
        """
        Get index statistics including vector count and namespaces.

        Returns:
            Dictionary with index stats including total_vector_count and namespaces.
        """
        stats = self.index.describe_index_stats()
        return {
            "total_vector_count": stats.total_vector_count,
            "namespaces": dict(stats.namespaces) if stats.namespaces else {},
        }

    @staticmethod
    def exponential_backoff_retry(
        func: Callable[[], Any], max_retries: int = 5
    ) -> Any:
        """
        Retry function with exponential backoff on 429/5xx, fail on 4xx.

        Args:
            func: Function to retry (should be a callable that takes no args)
            max_retries: Maximum number of attempts; must be at least 1.

        Returns:
            Result of func()

        Raises:
            ValueError: If max_retries is less than 1.
            PineconeException: If retries exhausted or non-retryable error
        """
        if max_retries < 1:
            # Guard: previously a non-positive value fell through the loop and
            # silently returned None without ever calling func().
            raise ValueError("max_retries must be at least 1")

        for attempt in range(max_retries):
            try:
                return func()
            except PineconeException as e:
                status_code = getattr(e, "status", None)
                # Only retry transient errors (rate limiting and server faults)
                if status_code and (status_code >= 500 or status_code == 429):
                    if attempt < max_retries - 1:
                        delay = min(2 ** attempt, 60)  # Cap at 60s
                        logger.warning(
                            f"Retryable error (status {status_code}), "
                            f"retrying in {delay}s (attempt {attempt + 1}/{max_retries})"
                        )
                        time.sleep(delay)
                    else:
                        logger.error(
                            f"Max retries ({max_retries}) exceeded for retryable error"
                        )
                        raise
                else:
                    # Don't retry client errors
                    logger.error(f"Non-retryable error (status {status_code}): {e}")
                    raise
            except Exception as e:
                # Non-Pinecone exceptions should not be retried
                logger.error(f"Non-retryable exception: {e}")
                raise

    def batch_upsert(
        self, records: list[PineconeRecord], batch_size: int = 96
    ) -> None:
        """
        Upsert records in batches of specified size with rate limiting.

        Args:
            records: List of PineconeRecord objects to upsert
            batch_size: Number of records per batch (default: 96)
        """
        if not records:
            logger.info("No records to upsert")
            return

        total_batches = (len(records) + batch_size - 1) // batch_size
        logger.info(
            f"Upserting {len(records)} records in {total_batches} batch(es) "
            f"(batch size: {batch_size})"
        )

        for i in range(0, len(records), batch_size):
            batch = records[i : i + batch_size]
            batch_num = (i // batch_size) + 1

            # Convert PineconeRecord models to dict format for Pinecone
            batch_dicts = [self._record_to_dict(record) for record in batch]

            logger.debug(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} records)")

            # Retry with exponential backoff; bind the batch as a default arg
            # so the lambda captures this iteration's value, not the loop var.
            self.exponential_backoff_retry(
                lambda b=batch_dicts: self.index.upsert_records(
                    namespace=self.namespace, records=b
                )
            )

            # Rate limiting between batches (skip the pause after the last one)
            if i + batch_size < len(records):
                time.sleep(0.1)

        logger.info(f"Successfully upserted {len(records)} records")

    @staticmethod
    def _record_to_dict(record: PineconeRecord) -> dict[str, Any]:
        """
        Convert PineconeRecord model to dict format for Pinecone API.

        Handles optional fields by omitting them if None. Pinecone doesn't accept
        null values for metadata fields, so parent_id must be omitted entirely
        when None (for root nodes).

        Args:
            record: PineconeRecord model instance

        Returns:
            Dictionary ready for Pinecone upsert_records
        """
        # Use by_alias=True to serialize 'id' as '_id' per model serialization_alias
        record_dict = record.model_dump(exclude_none=False, by_alias=True)

        # Remove None values for optional fields
        optional_fields = {
            "asn_identifier",
            "statement_notation",
            "statement_label",
            "normalized_subject",
            "publication_status",
            "parent_id",  # Must be omitted when None (Pinecone doesn't accept null)
            "document_id",
            "document_valid",
        }
        for field in optional_fields:
            if record_dict.get(field) is None:
                record_dict.pop(field, None)

        return record_dict

    def search_standards(
        self,
        query_text: str,
        top_k: int = 5,
        grade: str | None = None,
    ) -> list[dict]:
        """
        Perform semantic search over standards.

        Args:
            query_text: Natural language query
            top_k: Maximum number of results
            grade: Optional grade filter

        Returns:
            List of result dictionaries with metadata and scores
        """
        # Build filter dictionary dynamically.
        # Always filter to only leaf nodes (actual standards, not parent categories),
        # so there is always at least one condition and the filter is always set.
        filter_parts: list[dict[str, Any]] = [{"is_leaf": {"$eq": True}}]

        if grade:
            filter_parts.append({"education_levels": {"$in": [grade]}})

        # A single condition goes in as-is; multiple conditions need $and.
        filter_dict = (
            filter_parts[0] if len(filter_parts) == 1 else {"$and": filter_parts}
        )

        # Build query dictionary
        query_dict: dict[str, Any] = {
            "inputs": {"text": query_text},
            "top_k": top_k * 2,  # Get more candidates for reranking
            "filter": filter_dict,
        }

        # Call search with reranking
        results = self.index.search(
            namespace=self.namespace,
            query=query_dict,
            rerank={"model": "bge-reranker-v2-m3", "top_n": top_k, "rank_fields": ["content"]},
        )

        # Parse results
        hits = results.get("result", {}).get("hits", [])
        parsed_results = []
        for hit in hits:
            result_dict = {
                "_id": hit["_id"],
                "score": hit["_score"],
                **hit.get("fields", {}),
            }
            parsed_results.append(result_dict)

        return parsed_results

    def fetch_standard(self, standard_id: str) -> dict | None:
        """
        Fetch a standard by its GUID (_id field only).

        This method performs a direct lookup using Pinecone's fetch() API, which only
        works with the standard's GUID (_id field). It does NOT search by statement_notation,
        asn_identifier, or any other metadata fields.

        Args:
            standard_id: Standard GUID (_id field) - must be the exact GUID format
                (e.g., "EA60C8D165F6481B90BFF782CE193F93")

        Returns:
            Standard dictionary with metadata, or None if not found
        """
        result = self.index.fetch(ids=[standard_id], namespace=self.namespace)

        # Extract vectors from FetchResponse
        # FetchResponse.vectors is a dict mapping ID to Vector objects
        vectors = result.vectors

        if not vectors or standard_id not in vectors:
            return None

        vector = vectors[standard_id]

        # Extract metadata from Vector object
        # Vector has: id, values (embedding), and metadata (dict with all fields)
        metadata = vector.metadata or {}
        vector_id = vector.id

        # Combine _id with all metadata fields
        record_dict = {
            "_id": vector_id,
            **metadata,
        }

        return record_dict

    @staticmethod
    def is_uploaded(set_dir: Path) -> bool:
        """
        Check for .pinecone_uploaded marker file.

        Args:
            set_dir: Path to standard set directory

        Returns:
            True if marker file exists, False otherwise
        """
        marker_file = set_dir / ".pinecone_uploaded"
        return marker_file.exists()

    @staticmethod
    def mark_uploaded(set_dir: Path) -> None:
        """
        Create marker file with ISO 8601 timestamp.

        Args:
            set_dir: Path to standard set directory
        """
        marker_file = set_dir / ".pinecone_uploaded"
        timestamp = datetime.now(timezone.utc).isoformat()
        marker_file.write_text(timestamp, encoding="utf-8")
        logger.debug(f"Created upload marker: {marker_file}")

    @staticmethod
    def get_upload_timestamp(set_dir: Path) -> str | None:
        """
        Read timestamp from marker file.

        Args:
            set_dir: Path to standard set directory

        Returns:
            ISO 8601 timestamp string if marker exists, None otherwise
        """
        marker_file = set_dir / ".pinecone_uploaded"
        if not marker_file.exists():
            return None

        try:
            return marker_file.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError) as e:
            # Narrowed from bare `Exception`: only I/O and decode failures are
            # expected from read_text(); anything else should surface.
            logger.warning(f"Failed to read upload marker {marker_file}: {e}")
            return None