"""Pinecone client for uploading and managing standard records."""
from __future__ import annotations
import time
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from loguru import logger
from pinecone import Pinecone
from pinecone.exceptions import PineconeException
from src.mcp_config import get_mcp_settings
from tools.pinecone_models import PineconeRecord
settings = get_mcp_settings()
class PineconeClient:
"""Client for interacting with Pinecone index."""
def __init__(self) -> None:
"""Initialize Pinecone SDK from config settings."""
api_key = settings.pinecone_api_key
if not api_key:
raise ValueError("PINECONE_API_KEY environment variable not set")
self.pc = Pinecone(api_key=api_key)
self.index_name = settings.pinecone_index_name
self.namespace = settings.pinecone_namespace
self._index = None
    @property
    def index(self):
        """Lazily create and cache a handle to the Pinecone index."""
        if self._index is None:
            self._index = self.pc.Index(self.index_name)
        return self._index
def validate_index(self) -> None:
"""
        Check that the index exists via pc.has_index() and raise a helpful error if not.
Raises:
ValueError: If index does not exist, with instructions to create it.
"""
if not self.pc.has_index(name=self.index_name):
raise ValueError(
f"Index '{self.index_name}' not found. Create it with:\n"
f"pc index create -n {self.index_name} -m cosine -c aws -r us-east-1 "
f"--model llama-text-embed-v2 --field_map text=content"
)
def ensure_index_exists(self) -> bool:
"""
Check if index exists, create it if not.
        Creates the index with integrated embeddings using the llama-text-embed-v2 model.
Returns:
True if index was created, False if it already existed.
"""
if self.pc.has_index(name=self.index_name):
logger.info(f"Index '{self.index_name}' already exists")
return False
logger.info(f"Creating index '{self.index_name}' with integrated embeddings...")
self.pc.create_index_for_model(
name=self.index_name,
cloud="aws",
region="us-east-1",
embed={
"model": "llama-text-embed-v2",
"field_map": {"text": "content"},
},
)
logger.info(f"Successfully created index '{self.index_name}'")
return True
def get_index_stats(self) -> dict[str, Any]:
"""
Get index statistics including vector count and namespaces.
Returns:
Dictionary with index stats including total_vector_count and namespaces.
"""
stats = self.index.describe_index_stats()
return {
"total_vector_count": stats.total_vector_count,
"namespaces": dict(stats.namespaces) if stats.namespaces else {},
}
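    # Illustrative shape of the returned dict (counts and namespace name are
    # hypothetical):
    #   {"total_vector_count": 1234, "namespaces": {"standards": {"vector_count": 1234}}}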
@staticmethod
def exponential_backoff_retry(
func: Callable[[], Any], max_retries: int = 5
) -> Any:
"""
Retry function with exponential backoff on 429/5xx, fail on 4xx.
Args:
func: Function to retry (should be a callable that takes no args)
max_retries: Maximum number of retry attempts
Returns:
Result of func()
Raises:
PineconeException: If retries exhausted or non-retryable error
"""
for attempt in range(max_retries):
try:
return func()
except PineconeException as e:
status_code = getattr(e, "status", None)
# Only retry transient errors
if status_code and (status_code >= 500 or status_code == 429):
if attempt < max_retries - 1:
delay = min(2 ** attempt, 60) # Cap at 60s
logger.warning(
f"Retryable error (status {status_code}), "
f"retrying in {delay}s (attempt {attempt + 1}/{max_retries})"
)
time.sleep(delay)
else:
logger.error(
f"Max retries ({max_retries}) exceeded for retryable error"
)
raise
else:
# Don't retry client errors
logger.error(f"Non-retryable error (status {status_code}): {e}")
raise
except Exception as e:
# Non-Pinecone exceptions should not be retried
logger.error(f"Non-retryable exception: {e}")
raise
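    # Illustrative usage: wrap the API call in a zero-argument lambda so it can
    # be re-invoked on each attempt (the names here are examples, not fixed API):
    #   client.exponential_backoff_retry(
    #       lambda: client.index.upsert_records(namespace="ns", records=batch)
    #   )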
def batch_upsert(
self, records: list[PineconeRecord], batch_size: int = 96
) -> None:
"""
        Upsert records in batches of the specified size, with rate limiting between batches.
Args:
records: List of PineconeRecord objects to upsert
batch_size: Number of records per batch (default: 96)
"""
if not records:
logger.info("No records to upsert")
return
total_batches = (len(records) + batch_size - 1) // batch_size
logger.info(
f"Upserting {len(records)} records in {total_batches} batch(es) "
f"(batch size: {batch_size})"
)
for i in range(0, len(records), batch_size):
batch = records[i : i + batch_size]
batch_num = (i // batch_size) + 1
# Convert PineconeRecord models to dict format for Pinecone
batch_dicts = [self._record_to_dict(record) for record in batch]
logger.debug(f"Upserting batch {batch_num}/{total_batches} ({len(batch)} records)")
# Retry with exponential backoff
self.exponential_backoff_retry(
lambda b=batch_dicts: self.index.upsert_records(
namespace=self.namespace, records=b
)
)
# Rate limiting between batches
if i + batch_size < len(records):
time.sleep(0.1)
logger.info(f"Successfully upserted {len(records)} records")
@staticmethod
def _record_to_dict(record: PineconeRecord) -> dict[str, Any]:
"""
Convert PineconeRecord model to dict format for Pinecone API.
Handles optional fields by omitting them if None. Pinecone doesn't accept
null values for metadata fields, so parent_id must be omitted entirely
when None (for root nodes).
Args:
record: PineconeRecord model instance
Returns:
Dictionary ready for Pinecone upsert_records
"""
# Use by_alias=True to serialize 'id' as '_id' per model serialization_alias
record_dict = record.model_dump(exclude_none=False, by_alias=True)
# Remove None values for optional fields
optional_fields = {
"asn_identifier",
"statement_notation",
"statement_label",
"normalized_subject",
"publication_status",
"parent_id", # Must be omitted when None (Pinecone doesn't accept null)
"document_id",
"document_valid",
}
for field in optional_fields:
if record_dict.get(field) is None:
record_dict.pop(field, None)
return record_dict
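    # Illustrative transformation (field values hypothetical): a root node's
    # parent_id of None is dropped entirely rather than serialized as null:
    #   PineconeRecord(id="GUID1", content="Mathematics", parent_id=None, ...)
    #   -> {"_id": "GUID1", "content": "Mathematics", ...}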
def search_standards(
self,
query_text: str,
top_k: int = 5,
grade: str | None = None,
) -> list[dict]:
"""
Perform semantic search over standards.
Args:
query_text: Natural language query
top_k: Maximum number of results
grade: Optional grade filter
Returns:
List of result dictionaries with metadata and scores
"""
# Build filter dictionary dynamically
# Always filter to only leaf nodes (actual standards, not parent categories)
filter_parts = [{"is_leaf": {"$eq": True}}]
if grade:
filter_parts.append({"education_levels": {"$in": [grade]}})
        # filter_parts always has at least one entry; wrap in $and only when
        # combining multiple conditions
        filter_dict: dict[str, Any] = (
            filter_parts[0] if len(filter_parts) == 1 else {"$and": filter_parts}
        )
        # Build query dictionary
        query_dict: dict[str, Any] = {
            "inputs": {"text": query_text},
            "top_k": top_k * 2,  # Get more candidates for reranking
            "filter": filter_dict,
        }
# Call search with reranking
results = self.index.search(
namespace=self.namespace,
query=query_dict,
rerank={"model": "bge-reranker-v2-m3", "top_n": top_k, "rank_fields": ["content"]},
)
# Parse results
hits = results.get("result", {}).get("hits", [])
parsed_results = []
for hit in hits:
result_dict = {
"_id": hit["_id"],
"score": hit["_score"],
**hit.get("fields", {}),
}
parsed_results.append(result_dict)
return parsed_results
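    # Illustrative call and result shape (values hypothetical); hits are
    # reranked leaf standards only:
    #   hits = client.search_standards("adding fractions with unlike denominators",
    #                                  top_k=3, grade="4")
    #   # [{"_id": "GUID...", "score": 0.87, "content": "...", "is_leaf": True, ...}]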
def fetch_standard(self, standard_id: str) -> dict | None:
"""
Fetch a standard by its GUID (_id field only).
This method performs a direct lookup using Pinecone's fetch() API, which only
works with the standard's GUID (_id field). It does NOT search by statement_notation,
asn_identifier, or any other metadata fields.
Args:
standard_id: Standard GUID (_id field) - must be the exact GUID format
(e.g., "EA60C8D165F6481B90BFF782CE193F93")
Returns:
Standard dictionary with metadata, or None if not found
"""
result = self.index.fetch(ids=[standard_id], namespace=self.namespace)
# Extract vectors from FetchResponse
# FetchResponse.vectors is a dict mapping ID to Vector objects
vectors = result.vectors
if not vectors or standard_id not in vectors:
return None
vector = vectors[standard_id]
# Extract metadata from Vector object
# Vector has: id, values (embedding), and metadata (dict with all fields)
metadata = vector.metadata or {}
vector_id = vector.id
# Combine _id with all metadata fields
record_dict = {
"_id": vector_id,
**metadata,
}
return record_dict
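    # Illustrative usage: fetch requires the exact GUID _id, not a notation
    # such as "CCSS.Math.4.NF.A.1" (that notation is a hypothetical example):
    #   std = client.fetch_standard("EA60C8D165F6481B90BFF782CE193F93")
    #   if std is None:
    #       logger.warning("Standard not found in namespace")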
@staticmethod
def is_uploaded(set_dir: Path) -> bool:
"""
Check for .pinecone_uploaded marker file.
Args:
set_dir: Path to standard set directory
Returns:
True if marker file exists, False otherwise
"""
marker_file = set_dir / ".pinecone_uploaded"
return marker_file.exists()
@staticmethod
def mark_uploaded(set_dir: Path) -> None:
"""
Create marker file with ISO 8601 timestamp.
Args:
set_dir: Path to standard set directory
"""
marker_file = set_dir / ".pinecone_uploaded"
timestamp = datetime.now(timezone.utc).isoformat()
marker_file.write_text(timestamp, encoding="utf-8")
logger.debug(f"Created upload marker: {marker_file}")
@staticmethod
def get_upload_timestamp(set_dir: Path) -> str | None:
"""
Read timestamp from marker file.
Args:
set_dir: Path to standard set directory
Returns:
ISO 8601 timestamp string if marker exists, None otherwise
"""
marker_file = set_dir / ".pinecone_uploaded"
if not marker_file.exists():
return None
try:
return marker_file.read_text(encoding="utf-8").strip()
except Exception as e:
logger.warning(f"Failed to read upload marker {marker_file}: {e}")
return None
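

# A minimal end-to-end usage sketch (not part of the library surface): it
# assumes the index already exists and PINECONE_API_KEY is set; the query text
# and grade below are hypothetical example values.
if __name__ == "__main__":
    client = PineconeClient()
    client.validate_index()  # Fails fast with creation instructions if missing
    logger.info(f"Index stats: {client.get_index_stats()}")
    for hit in client.search_standards("multiply two-digit numbers", top_k=3, grade="4"):
        logger.info(f"{hit['_id']} (score {hit['score']:.3f})")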