|
|
"""Processor for transforming standard sets into Pinecone-ready format.""" |
|
|
|
|
|
from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING

from loguru import logger

from tools.config import get_settings
from tools.models import StandardSet, StandardSetResponse
from tools.pinecone_models import PineconeRecord, ProcessedStandardSet

if TYPE_CHECKING:
    from collections.abc import Mapping

    from tools.models import Standard
|
|
|
|
|
settings = get_settings() |
|
|
|
|
|
|
|
|
class StandardSetProcessor:
    """Processes standard sets into Pinecone-ready format.

    For each standard set, builds hierarchy lookup structures (id map,
    parent -> children map, leaf and root sets), then transforms every
    standard into a PineconeRecord carrying that hierarchy metadata.
    """

    def __init__(self) -> None:
        """Initialize empty relationship maps (populated per standard set)."""
        # standard ID -> standard dict (result of Standard.model_dump()).
        self.id_to_standard: dict[str, dict] = {}
        # parent ID (None for roots) -> child IDs sorted by position ascending.
        self.parent_to_children: dict[str | None, list[str]] = {}
        # IDs that never appear as any standard's parentId.
        self.leaf_nodes: set[str] = set()
        # IDs whose parentId is None.
        self.root_nodes: set[str] = set()

    def process_standard_set(self, standard_set: StandardSet) -> ProcessedStandardSet:
        """
        Process a standard set into Pinecone-ready records.

        Args:
            standard_set: The StandardSet model from the API

        Returns:
            ProcessedStandardSet with all records ready for Pinecone
        """
        self._build_relationship_maps(standard_set.standards)

        # The mapping keys duplicate each standard's own "id" field, so only
        # the values are needed here.
        records = [
            self._transform_standard(standard, standard_set)
            for standard in standard_set.standards.values()
        ]

        return ProcessedStandardSet(records=records)

    def _build_relationship_maps(self, standards: dict[str, Standard]) -> None:
        """
        Build helper data structures from all standards in the set.

        Args:
            standards: Dictionary mapping standard ID to Standard object
        """
        # Work on plain dicts so the downstream helpers stay model-agnostic.
        standards_dict = {
            std_id: standard.model_dump() for std_id, standard in standards.items()
        }

        self.id_to_standard = self._build_id_to_standard_map(standards_dict)
        self.parent_to_children = self._build_parent_to_children_map(standards_dict)
        self.leaf_nodes = self._identify_leaf_nodes(standards_dict)
        self.root_nodes = self._identify_root_nodes(standards_dict)

    def _build_id_to_standard_map(
        self, standards: dict[str, dict]
    ) -> dict[str, dict]:
        """Build map of id -> standard object.

        The input is already keyed by ID; a shallow copy avoids aliasing the
        caller's dict (the original identity comprehension did the same).
        """
        return dict(standards)

    def _build_parent_to_children_map(
        self, standards: dict[str, dict]
    ) -> dict[str | None, list[str]]:
        """
        Build map of parentId -> [child_ids], sorted by position ascending.

        Args:
            standards: Dictionary of standard ID to standard dict

        Returns:
            Dictionary mapping parent ID (or None for roots) to sorted list of child IDs
        """
        # Collect (position, id) pairs per parent, then sort each group.
        parent_map: dict[str | None, list[tuple[int, str]]] = {}

        for std_id, std in standards.items():
            parent_id = std.get("parentId")
            # A missing position defaults to 0, which sorts first.
            position = std.get("position", 0)
            parent_map.setdefault(parent_id, []).append((position, std_id))

        result: dict[str | None, list[str]] = {}
        for parent_id, children in parent_map.items():
            # Sort by position only; the stable sort preserves insertion
            # order for equal positions (sorting full tuples would not).
            children.sort(key=lambda item: item[0])
            result[parent_id] = [std_id for _, std_id in children]

        return result

    def _identify_leaf_nodes(self, standards: dict[str, dict]) -> set[str]:
        """
        Identify leaf nodes: standards whose ID does NOT appear as any standard's parentId.

        Args:
            standards: Dictionary of standard ID to standard dict

        Returns:
            Set of standard IDs that are leaf nodes
        """
        all_ids = set(standards)
        parent_ids = {
            std.get("parentId")
            for std in standards.values()
            if std.get("parentId") is not None
        }

        return all_ids - parent_ids

    def _identify_root_nodes(self, standards: dict[str, dict]) -> set[str]:
        """
        Identify root nodes: standards where parentId is null.

        Args:
            standards: Dictionary of standard ID to standard dict

        Returns:
            Set of standard IDs that are root nodes
        """
        return {
            std_id
            for std_id, std in standards.items()
            if std.get("parentId") is None
        }

    def find_root_id(self, standard: dict, id_to_standard: dict[str, dict]) -> str:
        """
        Walk up the parent chain to find the root ancestor.

        Cycle-safe: if a parent ID repeats or is unknown, the walk stops and
        the highest ancestor reached so far is returned.

        Args:
            standard: The standard dict to find root for
            id_to_standard: Map of ID to standard dict

        Returns:
            The root ancestor's ID
        """
        current = standard
        visited: set[str] = set()

        while current.get("parentId") is not None:
            parent_id = current["parentId"]
            if parent_id in visited:
                break  # cycle guard: malformed data would otherwise loop forever
            visited.add(parent_id)

            if parent_id not in id_to_standard:
                break  # dangling parent reference; stop at the current node
            current = id_to_standard[parent_id]

        return current["id"]

    def build_ordered_ancestors(
        self, standard: dict, id_to_standard: dict[str, dict]
    ) -> list[str]:
        """
        Build ancestor list ordered from root (index 0) to immediate parent (last index).

        Cycle-safe and tolerant of dangling parent references (the walk
        simply stops there).

        Args:
            standard: The standard dict to build ancestors for
            id_to_standard: Map of ID to standard dict

        Returns:
            List of ancestor IDs ordered root -> immediate parent
        """
        ancestors: list[str] = []
        current_id = standard.get("parentId")
        visited: set[str] = set()

        while current_id is not None and current_id not in visited:
            visited.add(current_id)
            if current_id not in id_to_standard:
                break  # dangling reference: stop the walk
            ancestors.append(current_id)
            current_id = id_to_standard[current_id].get("parentId")

        # Collected bottom-up; reverse so the root comes first.
        ancestors.reverse()
        return ancestors

    def _compute_sibling_count(self, standard: dict) -> int:
        """
        Count standards with same parent_id, excluding self.

        Args:
            standard: The standard dict

        Returns:
            Number of siblings (excluding self)
        """
        parent_id = standard.get("parentId")
        siblings = self.parent_to_children.get(parent_id, [])

        # Exclude the standard itself from its own sibling count.
        return sum(1 for sibling_id in siblings if sibling_id != standard["id"])

    @staticmethod
    def _format_hierarchy_line(std: dict) -> str:
        """Format one 'Depth N (notation): description' line for a standard."""
        depth = std.get("depth", 0)
        description = std.get("description", "")
        notation = std.get("statementNotation")

        if notation:
            return f"Depth {depth} ({notation}): {description}"
        return f"Depth {depth}: {description}"

    def _build_content_text(self, standard: dict) -> str:
        """
        Generate content text block with full hierarchy.

        Format: "Depth N (notation): description" for each ancestor and self.

        Args:
            standard: The standard dict

        Returns:
            Multi-line text block with full hierarchy
        """
        ancestor_ids = self.build_ordered_ancestors(standard, self.id_to_standard)

        # One line per ancestor (root first), then one line for the standard
        # itself; both use the same formatting helper.
        lines = [
            self._format_hierarchy_line(self.id_to_standard[ancestor_id])
            for ancestor_id in ancestor_ids
        ]
        lines.append(self._format_hierarchy_line(standard))

        return "\n".join(lines)

    def _transform_standard(
        self, standard: Standard, standard_set: StandardSet
    ) -> PineconeRecord:
        """
        Transform a single standard into a PineconeRecord.

        Args:
            standard: The Standard object to transform
            standard_set: The parent StandardSet containing context

        Returns:
            PineconeRecord ready for Pinecone upsert
        """
        std_dict = standard.model_dump()

        # Hierarchy metadata derived from the maps built in
        # _build_relationship_maps (which must run before this).
        is_root = std_dict.get("parentId") is None
        root_id = (
            std_dict["id"] if is_root else self.find_root_id(std_dict, self.id_to_standard)
        )
        ancestor_ids = self.build_ordered_ancestors(std_dict, self.id_to_standard)
        child_ids = self.parent_to_children.get(std_dict["id"], [])
        is_leaf = std_dict["id"] in self.leaf_nodes
        sibling_count = self._compute_sibling_count(std_dict)

        content = self._build_content_text(std_dict)

        parent_id = std_dict.get("parentId")

        record_data = {
            "id": std_dict["id"],
            "content": content,
            "standard_set_id": standard_set.id,
            "standard_set_title": standard_set.title,
            "subject": standard_set.subject,
            "normalized_subject": standard_set.normalizedSubject,
            "education_levels": standard_set.educationLevels,
            "document_id": standard_set.document.id,
            "document_valid": standard_set.document.valid,
            "publication_status": standard_set.document.publicationStatus,
            "jurisdiction_id": standard_set.jurisdiction.id,
            "jurisdiction_title": standard_set.jurisdiction.title,
            "depth": std_dict.get("depth", 0),
            "is_leaf": is_leaf,
            "is_root": is_root,
            "parent_id": parent_id,
            "root_id": root_id,
            "ancestor_ids": ancestor_ids,
            "child_ids": child_ids,
            "sibling_count": sibling_count,
        }

        # Optional identifiers: include only when present and truthy, so the
        # record omits empty/None values.
        if std_dict.get("asnIdentifier"):
            record_data["asn_identifier"] = std_dict["asnIdentifier"]
        if std_dict.get("statementNotation"):
            record_data["statement_notation"] = std_dict["statementNotation"]
        if std_dict.get("statementLabel"):
            record_data["statement_label"] = std_dict["statementLabel"]

        return PineconeRecord(**record_data)
|
|
|
|
|
|
|
|
def process_and_save(standard_set_id: str) -> Path:
    """
    Load data.json, process it, and save processed.json.

    Args:
        standard_set_id: The ID of the standard set to process

    Returns:
        Path to the saved processed.json file

    Raises:
        FileNotFoundError: If data.json doesn't exist
        ValueError: If JSON is invalid
    """
    set_dir = settings.standard_sets_dir / standard_set_id
    data_file = set_dir / "data.json"

    if not data_file.exists():
        logger.warning(f"data.json not found for set {standard_set_id}, skipping")
        raise FileNotFoundError(f"data.json not found for set {standard_set_id}")

    # Parse the raw JSON payload; surface malformed files as ValueError.
    try:
        raw_data = json.loads(data_file.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in {data_file}: {e}") from e

    # Validate against the API response model; surface schema mismatches
    # as ValueError as well.
    try:
        standard_set = StandardSetResponse(**raw_data).data
    except Exception as e:
        raise ValueError(f"Failed to parse standard set data: {e}") from e

    processed_set = StandardSetProcessor().process_standard_set(standard_set)

    processed_file = set_dir / "processed.json"
    processed_file.parent.mkdir(parents=True, exist_ok=True)
    processed_file.write_text(
        json.dumps(processed_set.model_dump(mode="json"), indent=2),
        encoding="utf-8",
    )

    logger.info(
        f"Processed {standard_set_id}: {len(processed_set.records)} records"
    )

    return processed_file
|
|
|
|
|
|