import json
from typing import Dict, List

from graphgen.bases import BaseExtractor, BaseLLMWrapper
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
from graphgen.utils import compute_dict_hash, detect_main_language, logger


class SchemaGuidedExtractor(BaseExtractor):
    """
    Use a JSON/YAML schema or Pydantic model to guide the LLM to extract
    structured information from text.

    Usage example:
        schema = {
            "type": "legal contract",
            "description": "A legal contract for leasing property.",
            "properties": {
                "end_date": {"type": "string", "description": "The end date of the lease."},
                "leased_space": {"type": "string", "description": "Description of the space that is being leased."},
                "lessee": {"type": "string", "description": "The lessee's name (and possibly address)."},
                "lessor": {"type": "string", "description": "The lessor's name (and possibly address)."},
                "signing_date": {"type": "string", "description": "The date the contract was signed."},
                "start_date": {"type": "string", "description": "The start date of the lease."},
                "term_of_payment": {"type": "string", "description": "Description of the payment terms."},
                "designated_use": {"type": "string", "description": "Description of the designated use of the property being leased."},
                "extension_period": {"type": "string", "description": "Description of the extension options for the lease."},
                "expiration_date_of_lease": {"type": "string", "description": "The expiration date of the lease."},
            },
            "required": ["lessee", "lessor", "start_date", "end_date"],
        }

        extractor = SchemaGuidedExtractor(llm_client, schema)
        result = await extractor.extract({"chunk-0": {"content": text}})
    """

    def __init__(self, llm_client: BaseLLMWrapper, schema: dict):
        super().__init__(llm_client)
        self.schema = schema
        self.required_keys = self.schema.get("required")
        if not self.required_keys:
            # If no required keys are specified, use all keys from the schema as the default
            self.required_keys = list(self.schema.get("properties", {}).keys())

    def build_prompt(self, text: str) -> str:
        # Turn each schema property into a human-readable bullet for the prompt
        schema_explanation = ""
        for field, details in self.schema.get("properties", {}).items():
            description = details.get("description", "No description provided.")
            schema_explanation += f'- "{field}": {description}\n'

        # Pick the prompt template matching the dominant language of the text
        lang = detect_main_language(text)
        prompt = SCHEMA_GUIDED_EXTRACTION_PROMPT[lang].format(
            field=self.schema.get("name", "the document"),
            schema_explanation=schema_explanation,
            examples="",
            text=text,
        )
        return prompt

    async def extract(self, chunk: dict) -> dict:
        _chunk_id = list(chunk.keys())[0]
        text = chunk[_chunk_id].get("content", "")
        prompt = self.build_prompt(text)
        response = await self.llm_client.generate_answer(prompt)
        try:
            extracted_info = json.loads(response)

            # Ensure all required keys are present
            for key in self.required_keys:
                if key not in extracted_info:
                    extracted_info[key] = ""

            # Discard extractions that leave any required key empty
            if any(extracted_info[key] == "" for key in self.required_keys):
                logger.debug("Missing required keys in extraction: %s", extracted_info)
                return {}

            main_keys_info = {key: extracted_info[key] for key in self.required_keys}
            logger.debug("Extracted info: %s", extracted_info)

            # Add chunk metadata and key the record by the hash of its required fields
            extracted_info["_chunk_id"] = _chunk_id
            return {
                compute_dict_hash(main_keys_info, prefix="extract-"): extracted_info
            }
        except json.JSONDecodeError:
            logger.error("Failed to parse extraction response: %s", response)
            return {}

    @staticmethod
    async def merge_extractions(
        extraction_list: List[Dict[str, dict]]
    ) -> Dict[str, dict]:
        """
        Merge multiple extraction results based on their hashes.

        :param extraction_list: List of extraction results, each a dict mapping hash to record.
        :return: Merged extraction results.
        """
        merged: Dict[str, dict] = {}
        for ext in extraction_list:
            for h, rec in ext.items():
                if h not in merged:
                    merged[h] = rec.copy()
                else:
                    for k, v in rec.items():
                        if k not in merged[h] or merged[h][k] == v:
                            merged[h][k] = v
                        else:
                            # Conflicting values for the same key are concatenated
                            merged[h][k] = f"{merged[h][k]}{v}"
        return merged
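

# A minimal, hedged usage sketch: it assumes an already-constructed LLM client implementing
# BaseLLMWrapper (the `llm_client` parameter below is hypothetical) and chunks shaped as
# {chunk_id: {"content": ...}}, which is what `extract` expects. The schema fields here are
# illustrative only; the sketch is not part of the extractor itself.
if __name__ == "__main__":
    import asyncio

    async def _demo(llm_client: BaseLLMWrapper):
        schema = {
            "name": "legal contract",
            "properties": {
                "lessee": {"type": "string", "description": "The lessee's name."},
                "lessor": {"type": "string", "description": "The lessor's name."},
            },
            "required": ["lessee", "lessor"],
        }
        extractor = SchemaGuidedExtractor(llm_client, schema)

        chunks = [
            {"chunk-0": {"content": "Lease between Alice (lessor) and Bob (lessee) ..."}},
            {"chunk-1": {"content": "Bob (lessee) agrees to pay Alice (lessor) monthly ..."}},
        ]
        # Extract each chunk, then merge records that share the same required-field hash
        extractions = [await extractor.extract(c) for c in chunks]
        merged = await SchemaGuidedExtractor.merge_extractions(extractions)
        print(json.dumps(merged, indent=2, ensure_ascii=False))

    # asyncio.run(_demo(llm_client))  # requires a concrete BaseLLMWrapper instance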