# common_core_mcp / src /pinecone_client.py
# Lindow
# add mcp and gradio app
# 3fac7d8
"""Pinecone client for uploading and managing standard records."""
from __future__ import annotations
import time
from datetime import datetime, timezone
from pathlib import Path
from collections.abc import Callable
from typing import Any
from loguru import logger
from pinecone import Pinecone
from pinecone.exceptions import PineconeException
from src.mcp_config import get_mcp_settings
from tools.pinecone_models import PineconeRecord
# get_mcp_settings() is called once, when this module is imported.
# PineconeClient.__init__ reads the API key, index name and namespace from it.
settings = get_mcp_settings()
class PineconeClient:
    """Client for interacting with a Pinecone index.

    Wraps the Pinecone SDK for the operations this project needs: checking or
    creating the index, batched record upserts with retry, semantic search
    with reranking, direct fetch by ID, and bookkeeping of which standard-set
    directories have already been uploaded (via a marker file).
    """

    # Name of the marker file written into a standard-set directory after upload.
    _UPLOAD_MARKER = ".pinecone_uploaded"

    # Fields that are dropped from the upsert payload when their value is None.
    _OPTIONAL_FIELDS = frozenset(
        {
            "asn_identifier",
            "statement_notation",
            "statement_label",
            "normalized_subject",
            "publication_status",
            "parent_id",  # Must be omitted when None (Pinecone doesn't accept null)
            "document_id",
            "document_valid",
        }
    )

    def __init__(self) -> None:
        """Initialize Pinecone SDK from config settings.

        Raises:
            ValueError: If the configured API key is empty.
        """
        api_key = settings.pinecone_api_key
        if not api_key:
            raise ValueError("PINECONE_API_KEY environment variable not set")

        self.pc = Pinecone(api_key=api_key)
        self.index_name = settings.pinecone_index_name
        self.namespace = settings.pinecone_namespace
        # The index handle is created lazily by the `index` property.
        self._index = None

    @property
    def index(self):
        """Get the index object, creating it on first access and caching it."""
        if self._index is None:
            self._index = self.pc.Index(self.index_name)
        return self._index

    def validate_index(self) -> None:
        """
        Check that the index exists with pc.has_index().

        Raises:
            ValueError: If index does not exist, with instructions to create it.
        """
        if not self.pc.has_index(name=self.index_name):
            raise ValueError(
                f"Index '{self.index_name}' not found. Create it with:\n"
                f"pc index create -n {self.index_name} -m cosine -c aws -r us-east-1 "
                f"--model llama-text-embed-v2 --field_map text=content"
            )

    def ensure_index_exists(self) -> bool:
        """
        Check if index exists, create it if not.

        Creates the index with integrated embeddings using the
        llama-text-embed-v2 model, mapping the embedded text to the
        ``content`` field.

        Returns:
            True if index was created, False if it already existed.
        """
        if self.pc.has_index(name=self.index_name):
            logger.info(f"Index '{self.index_name}' already exists")
            return False

        logger.info(f"Creating index '{self.index_name}' with integrated embeddings...")
        self.pc.create_index_for_model(
            name=self.index_name,
            cloud="aws",
            region="us-east-1",
            embed={
                "model": "llama-text-embed-v2",
                "field_map": {"text": "content"},
            },
        )
        logger.info(f"Successfully created index '{self.index_name}'")
        return True

    def get_index_stats(self) -> dict[str, Any]:
        """
        Get index statistics including vector count and namespaces.

        Returns:
            Dictionary with ``total_vector_count`` and ``namespaces`` (an empty
            dict when the index reports no namespaces).
        """
        stats = self.index.describe_index_stats()
        return {
            "total_vector_count": stats.total_vector_count,
            "namespaces": dict(stats.namespaces) if stats.namespaces else {},
        }

    @staticmethod
    def exponential_backoff_retry(
        func: Callable[[], Any], max_retries: int = 5
    ) -> Any:
        """
        Call ``func`` and retry with exponential backoff on 429/5xx errors.

        A PineconeException whose ``status`` attribute is 429 or >= 500 is
        retried; any other PineconeException (including one without a status)
        and any non-Pinecone exception is logged and re-raised immediately.

        Args:
            func: Function to retry (should be a callable that takes no args)
            max_retries: Maximum number of attempts; must be at least 1

        Returns:
            Result of func()

        Raises:
            ValueError: If max_retries is less than 1 (otherwise func would
                never be called and None would be returned silently).
            PineconeException: If retries exhausted or non-retryable error
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")

        for attempt in range(max_retries):
            try:
                return func()
            except PineconeException as e:
                status_code = getattr(e, "status", None)
                # Only retry transient errors
                retryable = bool(status_code) and (
                    status_code >= 500 or status_code == 429
                )
                if not retryable:
                    # Don't retry client errors
                    logger.error(f"Non-retryable error (status {status_code}): {e}")
                    raise
                if attempt >= max_retries - 1:
                    logger.error(
                        f"Max retries ({max_retries}) exceeded for retryable error"
                    )
                    raise
                delay = min(2 ** attempt, 60)  # Cap at 60s
                logger.warning(
                    f"Retryable error (status {status_code}), "
                    f"retrying in {delay}s (attempt {attempt + 1}/{max_retries})"
                )
                time.sleep(delay)
            except Exception as e:
                # Non-Pinecone exceptions should not be retried
                logger.error(f"Non-retryable exception: {e}")
                raise

    def batch_upsert(
        self, records: list[PineconeRecord], batch_size: int = 96
    ) -> None:
        """
        Upsert records in batches of specified size with rate limiting.

        Each batch is sent through exponential_backoff_retry, and a short
        pause is inserted between consecutive batches.

        Args:
            records: List of PineconeRecord objects to upsert
            batch_size: Number of records per batch (default: 96); must be >= 1

        Raises:
            ValueError: If batch_size is less than 1 while there are records
                to upsert (a non-positive size would either crash or skip
                every record while still reporting success).
        """
        if not records:
            logger.info("No records to upsert")
            return
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        total_batches = (len(records) + batch_size - 1) // batch_size
        logger.info(
            f"Upserting {len(records)} records in {total_batches} batch(es) "
            f"(batch size: {batch_size})"
        )

        for i in range(0, len(records), batch_size):
            batch = records[i : i + batch_size]
            batch_num = (i // batch_size) + 1

            # Convert PineconeRecord models to dict format for Pinecone
            batch_dicts = [self._record_to_dict(record) for record in batch]
            logger.debug(
                f"Upserting batch {batch_num}/{total_batches} ({len(batch)} records)"
            )

            # Retry with exponential backoff; the default argument binds this
            # iteration's batch to the lambda.
            self.exponential_backoff_retry(
                lambda b=batch_dicts: self.index.upsert_records(
                    namespace=self.namespace, records=b
                )
            )

            # Rate limiting between batches (no pause after the last one)
            if i + batch_size < len(records):
                time.sleep(0.1)

        logger.info(f"Successfully upserted {len(records)} records")

    @staticmethod
    def _record_to_dict(record: PineconeRecord) -> dict[str, Any]:
        """
        Convert PineconeRecord model to dict format for Pinecone API.

        Handles optional fields by omitting them if None. Pinecone doesn't accept
        null values for metadata fields, so parent_id must be omitted entirely
        when None (for root nodes). Fields outside the optional set are kept
        even when their value is None.

        Args:
            record: PineconeRecord model instance

        Returns:
            Dictionary ready for Pinecone upsert_records
        """
        # Use by_alias=True to serialize 'id' as '_id' per model serialization_alias
        dumped = record.model_dump(exclude_none=False, by_alias=True)
        optional = PineconeClient._OPTIONAL_FIELDS
        return {
            key: value
            for key, value in dumped.items()
            if not (value is None and key in optional)
        }

    def search_standards(
        self,
        query_text: str,
        top_k: int = 5,
        grade: str | None = None,
    ) -> list[dict]:
        """
        Perform semantic search over standards.

        Results are always restricted to leaf nodes. Twice ``top_k`` candidates
        are requested and then reranked down to ``top_k``.

        Args:
            query_text: Natural language query
            top_k: Maximum number of results
            grade: Optional grade filter

        Returns:
            List of result dictionaries, each holding ``_id``, ``score`` and
            the hit's fields
        """
        # Always filter to only leaf nodes (actual standards, not parent categories)
        filter_parts: list[dict[str, Any]] = [{"is_leaf": {"$eq": True}}]
        if grade:
            filter_parts.append({"education_levels": {"$in": [grade]}})

        # A single condition is passed as-is; several are combined with $and.
        if len(filter_parts) == 1:
            filter_dict = filter_parts[0]
        else:
            filter_dict = {"$and": filter_parts}

        query_dict: dict[str, Any] = {
            "inputs": {"text": query_text},
            "top_k": top_k * 2,  # Get more candidates for reranking
            "filter": filter_dict,
        }

        # Call search with reranking
        results = self.index.search(
            namespace=self.namespace,
            query=query_dict,
            rerank={
                "model": "bge-reranker-v2-m3",
                "top_n": top_k,
                "rank_fields": ["content"],
            },
        )

        # Flatten each hit into one dict: id, score, then the stored fields
        hits = results.get("result", {}).get("hits", [])
        return [
            {
                "_id": hit["_id"],
                "score": hit["_score"],
                **hit.get("fields", {}),
            }
            for hit in hits
        ]

    def fetch_standard(self, standard_id: str) -> dict | None:
        """
        Fetch a standard by its GUID (_id field only).

        This method performs a direct lookup using Pinecone's fetch() API, which only
        works with the standard's GUID (_id field). It does NOT search by statement_notation,
        asn_identifier, or any other metadata fields.

        Args:
            standard_id: Standard GUID (_id field) - must be the exact GUID format
                (e.g., "EA60C8D165F6481B90BFF782CE193F93")

        Returns:
            Standard dictionary with metadata, or None if not found
        """
        result = self.index.fetch(ids=[standard_id], namespace=self.namespace)

        # FetchResponse.vectors is a dict mapping ID to Vector objects
        vectors = result.vectors
        if not vectors or standard_id not in vectors:
            return None

        # Combine _id with all metadata fields (metadata may be missing)
        vector = vectors[standard_id]
        return {
            "_id": vector.id,
            **(vector.metadata or {}),
        }

    @staticmethod
    def is_uploaded(set_dir: Path) -> bool:
        """
        Check for .pinecone_uploaded marker file.

        Args:
            set_dir: Path to standard set directory

        Returns:
            True if marker file exists, False otherwise
        """
        return (set_dir / PineconeClient._UPLOAD_MARKER).exists()

    @staticmethod
    def mark_uploaded(set_dir: Path) -> None:
        """
        Create marker file with ISO 8601 timestamp (UTC).

        Args:
            set_dir: Path to standard set directory
        """
        marker_file = set_dir / PineconeClient._UPLOAD_MARKER
        timestamp = datetime.now(timezone.utc).isoformat()
        marker_file.write_text(timestamp, encoding="utf-8")
        logger.debug(f"Created upload marker: {marker_file}")

    @staticmethod
    def get_upload_timestamp(set_dir: Path) -> str | None:
        """
        Read timestamp from marker file.

        Args:
            set_dir: Path to standard set directory

        Returns:
            Stripped marker file contents if the marker exists and is
            readable, None otherwise
        """
        marker_file = set_dir / PineconeClient._UPLOAD_MARKER
        if not marker_file.exists():
            return None
        try:
            return marker_file.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeError) as e:
            # Best effort: an unreadable marker is reported, not raised.
            logger.warning(f"Failed to read upload marker {marker_file}: {e}")
            return None