File size: 5,026 Bytes
283e483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ad3d05
 
 
283e483
 
 
 
 
 
 
 
 
 
 
 
 
8ad3d05
 
 
 
 
 
 
283e483
 
 
 
8ad3d05
283e483
8ad3d05
283e483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
from typing import Dict, List, Optional

from graphgen.bases import BaseExtractor, BaseLLMWrapper
from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT
from graphgen.utils import compute_dict_hash, detect_main_language, logger


class SchemaGuidedExtractor(BaseExtractor):
    """
    Use a JSON/YAML-style schema dict (e.g. one generated from a Pydantic model)
    to guide the LLM to extract structured information from text.

    The schema's ``properties`` (field name -> ``{"description": ...}``) are rendered
    into the prompt. Its ``required`` list decides which fields must be present and
    non-empty for an extraction to be kept; when ``required`` is absent or empty,
    every property is treated as required. An optional top-level ``name`` is passed
    to the prompt template (defaulting to "the document").

    Usage example:
        schema = {
            "type": "legal contract",
            "description": "A legal contract for leasing property.",
            "properties": {
                "end_date": {"type": "string", "description": "The end date of the lease."},
                "leased_space": {"type": "string", "description": "Description of the space that is being leased."},
                "lessee": {"type": "string", "description": "The lessee's name (and possibly address)."},
                "lessor": {"type": "string", "description": "The lessor's name (and possibly address)."},
                "signing_date": {"type": "string", "description": "The date the contract was signed."},
                "start_date": {"type": "string", "description": "The start date of the lease."},
                "term_of_payment": {"type": "string", "description": "Description of the payment terms."},
                "designated_use": {"type": "string",
                                   "description": "Description of the designated use of the property being leased."},
                "extension_period": {"type": "string",
                                     "description": "Description of the extension options for the lease."},
                "expiration_date_of_lease": {"type": "string", "description": "The expiration date of the lease."}
            },
            "required": ["lessee", "lessor", "start_date", "end_date"]
        }
        extractor = SchemaGuidedExtractor(llm_client, schema)
        result = await extractor.extract({"chunk-1": {"content": text}})

    """

    def __init__(self, llm_client: BaseLLMWrapper, schema: dict):
        """
        :param llm_client: LLM wrapper used to answer the extraction prompt.
        :param schema: Schema dict with ``properties`` and optionally ``required``/``name``.
        """
        super().__init__(llm_client)
        self.schema = schema
        self.required_keys = self.schema.get("required")
        if not self.required_keys:
            # If no required keys are specified, use all keys from the schema as default
            self.required_keys = list(self.schema.get("properties", {}).keys())

    def build_prompt(self, text: str) -> str:
        """
        Render the extraction prompt for ``text`` in its detected main language.

        :param text: Chunk content to extract from.
        :return: The formatted prompt string.
        """
        # One line per schema property, e.g.:  - "lessee": The lessee's name ...
        schema_explanation = "".join(
            f'- "{field}": {details.get("description", "No description provided.")}\n'
            for field, details in self.schema.get("properties", {}).items()
        )

        lang = detect_main_language(text)

        prompt = SCHEMA_GUIDED_EXTRACTION_PROMPT[lang].format(
            field=self.schema.get("name", "the document"),
            schema_explanation=schema_explanation,
            examples="",
            text=text,
        )
        return prompt

    @staticmethod
    def _parse_response(response) -> Optional[dict]:
        """
        Parse an LLM response into a JSON object.

        Tolerates a surrounding markdown code fence (```json ... ```).

        :param response: Raw LLM output.
        :return: The parsed dict, or ``None`` if the response is not a string,
            is not valid JSON, or is valid JSON that is not an object.
        """
        if not isinstance(response, str):
            logger.error("Failed to parse extraction response: %s", response)
            return None

        cleaned = response.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (``` or ```json) and the closing fence.
            cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            logger.error("Failed to parse extraction response: %s", response)
            return None

        # A JSON list/string/number cannot be keyed by field name.
        if not isinstance(parsed, dict):
            logger.error("Extraction response is not a JSON object: %s", response)
            return None
        return parsed

    async def extract(self, chunk: dict) -> dict:
        """
        Extract the schema fields from a single chunk.

        :param chunk: Mapping ``{chunk_id: {"content": text, ...}}``; only the first
            entry is used.
        :return: ``{hash: record}`` where ``hash`` is computed from the required-key
            values and ``record`` is the parsed LLM output plus ``"_chunk_id"``.
            Returns ``{}`` if the chunk is empty, the response cannot be parsed into
            a JSON object, or any required key is missing, ``""`` or ``null``.
        """
        if not chunk:
            logger.warning("Received an empty chunk; nothing to extract.")
            return {}
        _chunk_id = next(iter(chunk))
        text = chunk[_chunk_id].get("content", "")

        prompt = self.build_prompt(text)
        response = await self.llm_client.generate_answer(prompt)

        extracted_info = self._parse_response(response)
        if extracted_info is None:
            return {}

        # Every required key must be present with a non-empty, non-null value;
        # otherwise the extraction cannot be identified and is dropped.
        if any(extracted_info.get(key) in (None, "") for key in self.required_keys):
            logger.debug("Missing required keys in extraction: %s", extracted_info)
            return {}

        main_keys_info = {key: extracted_info[key] for key in self.required_keys}
        logger.debug("Extracted info: %s", extracted_info)

        # add chunk metadata
        extracted_info["_chunk_id"] = _chunk_id

        return {
            compute_dict_hash(main_keys_info, prefix="extract-"): extracted_info
        }

    @staticmethod
    async def merge_extractions(
        extraction_list: List[Dict[str, dict]]
    ) -> Dict[str, dict]:
        """
        Merge multiple extraction results based on their hashes.

        Records sharing a hash are combined field by field. A field not yet stored
        (or equal to the stored value) is copied as-is. Otherwise the field becomes a
        ``<SEP>``-joined string of the distinct values in first-seen order; values
        already present are not appended again.

        :param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
        :return: Merged extraction results.
        """
        sep = "<SEP>"
        merged: Dict[str, dict] = {}
        for ext in extraction_list:
            for h, rec in ext.items():
                if h not in merged:
                    merged[h] = rec.copy()
                    continue
                target = merged[h]
                for k, v in rec.items():
                    if k not in target or target[k] == v:
                        target[k] = v
                    else:
                        # Split both sides so a value already merged in
                        # (e.g. the same _chunk_id seen twice) is not duplicated.
                        parts = str(target[k]).split(sep) + str(v).split(sep)
                        target[k] = sep.join(dict.fromkeys(parts))
        return merged