Spaces:
Running
Running
File size: 3,499 Bytes
0b9d8c7 8c66169 0b9d8c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import re
from collections import defaultdict
from typing import Dict, List, Tuple
from graphgen.bases import Chunk
from graphgen.templates import MMKG_EXTRACTION_PROMPT
from graphgen.utils import (
detect_main_language,
handle_single_entity_extraction,
handle_single_relationship_extraction,
logger,
split_string_by_multi_markers,
)
from .light_rag_kg_builder import LightRAGKGBuilder
class MMKGBuilder(LightRAGKGBuilder):
    """Knowledge-graph builder for multi-modal chunks (image, table, formula, ...).

    Overrides ``extract`` from ``LightRAGKGBuilder`` so that each multi-modal
    chunk is turned into a mini graph via ``MMKG_EXTRACTION_PROMPT``. Only
    ``"image"`` chunks are currently implemented.
    """

    # Captures the text between the first "(" and the last ")" on a line of an
    # LLM output record (greedy ``.*``; ``.`` does not cross newlines).
    _RECORD_PATTERN = re.compile(r"\((.*)\)")

    async def extract(
        self, chunk: Chunk
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """
        Extract entities and relationships from a single multi-modal chunk using the LLM client.

        The LLM is expected to return a mini graph that contains a central
        multi-modal entity together with its related text entities and
        relationships, e.g.:
            (image: "image_of_eiffel_tower") --[located_in]--> (text: "Paris")
            (image: "image_of_eiffel_tower") --[built_in]--> (text: "1889")
            (text: "Eiffel Tower") --[height]--> (text: "324 meters")

        Only ``"image"`` chunks are handled at the moment: their caption
        (``metadata["image_caption"]``) is the text sent to the LLM. Every
        other chunk type, including ``"table"`` and ``"formula"``, is logged as
        unsupported and yields an empty graph.

        :param chunk: the multi-modal chunk to process.
        :return: ``(nodes, edges)``. ``nodes`` maps an entity name to the list
            of entity records extracted for it; ``edges`` maps a
            ``(src_id, tgt_id)`` pair to the list of relationship records.
        """
        chunk_id = chunk.id
        chunk_type = chunk.type  # image | table | formula | ...
        # Defensive: treat a missing (None) metadata mapping as empty.
        metadata = chunk.metadata or {}

        # choose different extraction strategies based on chunk type
        if chunk_type == "image":
            image_caption = self._join_caption(metadata.get("image_caption"))
            language = detect_main_language(image_caption)
            prompt = MMKG_EXTRACTION_PROMPT[language].format(
                **MMKG_EXTRACTION_PROMPT["FORMAT"],
                chunk_type=chunk_type,
                chunk_id=chunk_id,
                chunk_text=image_caption,
            )
            result = await self.llm_client.generate_answer(prompt)
            logger.debug("Image chunk extraction result: %s", result)
            return await self._parse_mmkg_result(result, chunk_id)

        # TODO: implement table-based and formula-based entity and relationship
        # extraction. Until then "table" and "formula" chunks are treated like
        # any other unknown type: an error is logged and an empty graph returned.
        logger.error("Unsupported chunk type for MMKGBuilder: %s", chunk_type)
        # Plain dicts, matching the return type of the image path above.
        return {}, {}

    @staticmethod
    def _join_caption(raw_caption: object) -> str:
        """Normalize an ``image_caption`` metadata value into one string.

        A sequence of caption fragments is joined with newlines. A plain string
        is returned unchanged, because ``"\\n".join`` on a string would insert
        a newline between every character. ``None`` (or a missing key) yields
        an empty string.
        """
        if raw_caption is None:
            return ""
        if isinstance(raw_caption, str):
            return raw_caption
        return "\n".join(raw_caption)

    async def _parse_mmkg_result(
        self, result: str, chunk_id
    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
        """Parse raw LLM extraction output into grouped entity / relation records.

        The output is split into records on the record and completion
        delimiters of ``MMKG_EXTRACTION_PROMPT["FORMAT"]``. Each record's
        parenthesised body is split on the tuple delimiter and tried first as an
        entity, then as a relationship. Records matching neither are dropped.

        :param result: raw text returned by the LLM client.
        :param chunk_id: id of the chunk the records were extracted from.
        :return: ``(nodes, edges)`` as plain dicts, in the same shape as ``extract``.
        """
        fmt = MMKG_EXTRACTION_PROMPT["FORMAT"]
        records = split_string_by_multi_markers(
            result,
            [fmt["record_delimiter"], fmt["completion_delimiter"]],
        )

        nodes: Dict[str, List[dict]] = defaultdict(list)
        edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)
        for record in records:
            match = self._RECORD_PATTERN.search(record)
            if match is None:
                continue
            attributes = split_string_by_multi_markers(
                match.group(1), [fmt["tuple_delimiter"]]
            )

            entity = await handle_single_entity_extraction(attributes, chunk_id)
            if entity is not None:
                nodes[entity["entity_name"]].append(entity)
                continue

            relation = await handle_single_relationship_extraction(
                attributes, chunk_id
            )
            if relation is not None:
                edges[(relation["src_id"], relation["tgt_id"])].append(relation)

        return dict(nodes), dict(edges)
|