"""Processor for transforming standard sets into Pinecone-ready format.""" from __future__ import annotations import json from pathlib import Path from typing import TYPE_CHECKING from loguru import logger from tools.config import get_settings from tools.models import StandardSet, StandardSetResponse from tools.pinecone_models import PineconeRecord, ProcessedStandardSet if TYPE_CHECKING: from collections.abc import Mapping settings = get_settings() class StandardSetProcessor: """Processes standard sets into Pinecone-ready format.""" def __init__(self): """Initialize the processor.""" self.id_to_standard: dict[str, dict] = {} self.parent_to_children: dict[str | None, list[str]] = {} self.leaf_nodes: set[str] = set() self.root_nodes: set[str] = set() def process_standard_set(self, standard_set: StandardSet) -> ProcessedStandardSet: """ Process a standard set into Pinecone-ready records. Args: standard_set: The StandardSet model from the API Returns: ProcessedStandardSet with all records ready for Pinecone """ # Build relationship maps from all standards self._build_relationship_maps(standard_set.standards) # Process each standard into a PineconeRecord records = [] for standard_id, standard in standard_set.standards.items(): record = self._transform_standard(standard, standard_set) records.append(record) return ProcessedStandardSet(records=records) def _build_relationship_maps(self, standards: dict[str, Standard]) -> None: """ Build helper data structures from all standards in the set. Args: standards: Dictionary mapping standard ID to Standard object """ # Convert to dict format for easier manipulation standards_dict = { std_id: standard.model_dump() for std_id, standard in standards.items() } # Build ID-to-standard map self.id_to_standard = self._build_id_to_standard_map(standards_dict) # Build parent-to-children map (sorted by position) self.parent_to_children = self._build_parent_to_children_map(standards_dict) # Identify leaf nodes self.leaf_nodes = self._identify_leaf_nodes(standards_dict) # Identify root nodes self.root_nodes = self._identify_root_nodes(standards_dict) def _build_id_to_standard_map( self, standards: dict[str, dict] ) -> dict[str, dict]: """Build map of id -> standard object.""" return {std_id: std for std_id, std in standards.items()} def _build_parent_to_children_map( self, standards: dict[str, dict] ) -> dict[str | None, list[str]]: """ Build map of parentId -> [child_ids], sorted by position ascending. Args: standards: Dictionary of standard ID to standard dict Returns: Dictionary mapping parent ID (or None for roots) to sorted list of child IDs """ parent_map: dict[str | None, list[tuple[int, str]]] = {} for std_id, std in standards.items(): parent_id = std.get("parentId") position = std.get("position", 0) if parent_id not in parent_map: parent_map[parent_id] = [] parent_map[parent_id].append((position, std_id)) # Sort each list by position and extract just the IDs result: dict[str | None, list[str]] = {} for parent_id, children in parent_map.items(): sorted_children = sorted(children, key=lambda x: x[0]) result[parent_id] = [std_id for _, std_id in sorted_children] return result def _identify_leaf_nodes(self, standards: dict[str, dict]) -> set[str]: """ Identify leaf nodes: standards whose ID does NOT appear as any standard's parentId. 

    def _identify_leaf_nodes(self, standards: dict[str, dict]) -> set[str]:
        """
        Identify leaf nodes: standards whose ID does NOT appear as any
        standard's parentId.

        Args:
            standards: Dictionary of standard ID to standard dict

        Returns:
            Set of standard IDs that are leaf nodes
        """
        all_ids = set(standards.keys())
        parent_ids = {
            std.get("parentId")
            for std in standards.values()
            if std.get("parentId") is not None
        }
        # Leaf nodes are IDs that are NOT in parent_ids
        return all_ids - parent_ids

    def _identify_root_nodes(self, standards: dict[str, dict]) -> set[str]:
        """
        Identify root nodes: standards where parentId is null.

        Args:
            standards: Dictionary of standard ID to standard dict

        Returns:
            Set of standard IDs that are root nodes
        """
        return {
            std_id
            for std_id, std in standards.items()
            if std.get("parentId") is None
        }

    def find_root_id(self, standard: dict, id_to_standard: dict[str, dict]) -> str:
        """
        Walk up the parent chain to find the root ancestor.

        Args:
            standard: The standard dict to find the root for
            id_to_standard: Map of ID to standard dict

        Returns:
            The root ancestor's ID
        """
        current = standard
        visited = set()  # Prevent infinite loops from bad data
        while current.get("parentId") is not None:
            parent_id = current["parentId"]
            if parent_id in visited:
                break  # Circular reference protection
            visited.add(parent_id)
            if parent_id not in id_to_standard:
                break  # Parent not found, use current as root
            current = id_to_standard[parent_id]
        return current["id"]

    def build_ordered_ancestors(
        self, standard: dict, id_to_standard: dict[str, dict]
    ) -> list[str]:
        """
        Build ancestor list ordered from root (index 0) to immediate parent
        (last index).

        Args:
            standard: The standard dict to build ancestors for
            id_to_standard: Map of ID to standard dict

        Returns:
            List of ancestor IDs ordered root -> immediate parent
        """
        ancestors = []
        current_id = standard.get("parentId")
        visited = set()
        while current_id is not None and current_id not in visited:
            visited.add(current_id)
            if current_id in id_to_standard:
                ancestors.append(current_id)
                current_id = id_to_standard[current_id].get("parentId")
            else:
                break
        ancestors.reverse()  # Now ordered root → immediate parent
        return ancestors

    def _compute_sibling_count(self, standard: dict) -> int:
        """
        Count standards with the same parent_id, excluding self.

        Args:
            standard: The standard dict

        Returns:
            Number of siblings (excluding self)
        """
        parent_id = standard.get("parentId")
        if parent_id not in self.parent_to_children:
            return 0
        siblings = self.parent_to_children[parent_id]
        # Exclude self from the count
        return len([s for s in siblings if s != standard["id"]])

    def _build_content_text(self, standard: dict) -> str:
        """
        Generate content text block with full hierarchy.

        Format: "Depth N (notation): description" for each ancestor and self.

        Args:
            standard: The standard dict

        Returns:
            Multi-line text block with the full hierarchy
        """
        # Build the ordered ancestor chain, then append the standard itself
        # so the whole path (root -> self) is formatted by a single loop
        ancestor_ids = self.build_ordered_ancestors(standard, self.id_to_standard)
        chain = [self.id_to_standard[aid] for aid in ancestor_ids] + [standard]

        lines = []
        for node in chain:
            depth = node.get("depth", 0)
            description = node.get("description", "")
            notation = node.get("statementNotation")
            if notation:
                lines.append(f"Depth {depth} ({notation}): {description}")
            else:
                lines.append(f"Depth {depth}: {description}")

        return "\n".join(lines)
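
    # Example of the text block the method above produces (standards and
    # notations are hypothetical, for illustration only):
    #
    #     Depth 0: Mathematics
    #     Depth 1 (3.OA): Operations and Algebraic Thinking
    #     Depth 2 (3.OA.1): Interpret products of whole numbers.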

    def _transform_standard(
        self, standard: Standard, standard_set: StandardSet
    ) -> PineconeRecord:
        """
        Transform a single standard into a PineconeRecord.

        Args:
            standard: The Standard object to transform
            standard_set: The parent StandardSet containing context

        Returns:
            PineconeRecord ready for Pinecone upsert
        """
        std_dict = standard.model_dump()

        # Compute hierarchy relationships
        is_root = std_dict.get("parentId") is None
        root_id = (
            std_dict["id"]
            if is_root
            else self.find_root_id(std_dict, self.id_to_standard)
        )
        ancestor_ids = self.build_ordered_ancestors(std_dict, self.id_to_standard)
        child_ids = self.parent_to_children.get(std_dict["id"], [])
        is_leaf = std_dict["id"] in self.leaf_nodes
        sibling_count = self._compute_sibling_count(std_dict)

        # Build content text
        content = self._build_content_text(std_dict)

        # Extract standard set context
        parent_id = std_dict.get("parentId")  # Keep as None if null

        # Build record with all fields
        # Note: Use "id" not "_id" - Pydantic handles the serialization
        # alias automatically
        record_data = {
            "id": std_dict["id"],
            "content": content,
            "standard_set_id": standard_set.id,
            "standard_set_title": standard_set.title,
            "subject": standard_set.subject,
            "normalized_subject": standard_set.normalizedSubject,  # Optional, can be None
            "education_levels": standard_set.educationLevels,
            "document_id": standard_set.document.id,
            "document_valid": standard_set.document.valid,
            "publication_status": standard_set.document.publicationStatus,  # Optional, can be None
            "jurisdiction_id": standard_set.jurisdiction.id,
            "jurisdiction_title": standard_set.jurisdiction.title,
            "depth": std_dict.get("depth", 0),
            "is_leaf": is_leaf,
            "is_root": is_root,
            "parent_id": parent_id,
            "root_id": root_id,
            "ancestor_ids": ancestor_ids,
            "child_ids": child_ids,
            "sibling_count": sibling_count,
        }

        # Add optional fields only if present
        if std_dict.get("asnIdentifier"):
            record_data["asn_identifier"] = std_dict["asnIdentifier"]
        if std_dict.get("statementNotation"):
            record_data["statement_notation"] = std_dict["statementNotation"]
        if std_dict.get("statementLabel"):
            record_data["statement_label"] = std_dict["statementLabel"]

        return PineconeRecord(**record_data)
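

# Minimal in-memory usage sketch (assumes `raw_data` is an already-parsed
# API response dict; the model and processor names come from this module):
#
#     response = StandardSetResponse(**raw_data)
#     processor = StandardSetProcessor()
#     processed = processor.process_standard_set(response.data)
#     records = processed.records  # PineconeRecord objects, ready to upsert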


def process_and_save(standard_set_id: str) -> Path:
    """
    Load data.json, process it, and save processed.json.

    Args:
        standard_set_id: The ID of the standard set to process

    Returns:
        Path to the saved processed.json file

    Raises:
        FileNotFoundError: If data.json doesn't exist
        ValueError: If the JSON is invalid
    """
    # Locate data.json
    data_file = settings.standard_sets_dir / standard_set_id / "data.json"
    if not data_file.exists():
        logger.warning(f"data.json not found for set {standard_set_id}")
        raise FileNotFoundError(f"data.json not found for set {standard_set_id}")

    # Load and parse JSON
    try:
        with open(data_file, encoding="utf-8") as f:
            raw_data = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in {data_file}: {e}") from e

    # Parse into the Pydantic model
    try:
        response = StandardSetResponse(**raw_data)
        standard_set = response.data
    except Exception as e:
        raise ValueError(f"Failed to parse standard set data: {e}") from e

    # Process the standard set
    processor = StandardSetProcessor()
    processed_set = processor.process_standard_set(standard_set)

    # Save processed.json
    processed_file = settings.standard_sets_dir / standard_set_id / "processed.json"
    processed_file.parent.mkdir(parents=True, exist_ok=True)
    with open(processed_file, "w", encoding="utf-8") as f:
        json.dump(processed_set.model_dump(mode="json"), f, indent=2)

    logger.info(f"Processed {standard_set_id}: {len(processed_set.records)} records")

    return processed_file
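

if __name__ == "__main__":
    # Minimal CLI sketch for ad-hoc runs; the argument handling here is an
    # illustration, not an established interface of this module.
    import sys

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} <standard_set_id>")
    output_path = process_and_save(sys.argv[1])
    logger.info(f"Wrote {output_path}")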