Commit 0b9d8c7
github-actions[bot] committed · 1 Parent(s): e4316f1

Auto-sync from demo at Thu Oct 23 11:07:54 UTC 2025
Files changed:

- graphgen/bases/base_reader.py +45 -0
- graphgen/bases/datatypes.py +10 -0
- graphgen/configs/vqa_config.yaml +5 -9
- graphgen/graphgen.py +111 -55
- graphgen/models/__init__.py +2 -1
- graphgen/models/generator/aggregated_generator.py +3 -3
- graphgen/models/generator/atomic_generator.py +2 -2
- graphgen/models/generator/cot_generator.py +3 -3
- graphgen/models/generator/multi_hop_generator.py +2 -2
- graphgen/models/generator/vqa_generator.py +122 -7
- graphgen/models/kg_builder/__init__.py +1 -0
- graphgen/models/kg_builder/light_rag_kg_builder.py +3 -10
- graphgen/models/kg_builder/mm_kg_builder.py +93 -0
- graphgen/models/partitioner/__init__.py +1 -0
- graphgen/models/partitioner/anchor_bfs_partitioner.py +128 -0
- graphgen/models/reader/csv_reader.py +6 -3
- graphgen/models/reader/json_reader.py +2 -2
- graphgen/models/reader/jsonl_reader.py +3 -4
- graphgen/models/reader/pdf_reader.py +1 -3
- graphgen/models/reader/txt_reader.py +1 -1
- graphgen/operators/__init__.py +1 -1
- graphgen/operators/build_kg/__init__.py +2 -1
- graphgen/operators/build_kg/build_mm_kg.py +56 -0
- graphgen/operators/build_kg/{build_kg.py → build_text_kg.py} +1 -1
- graphgen/operators/judge.py +5 -5
- graphgen/operators/partition/partition_kg.py +21 -1
- graphgen/operators/split/split_chunks.py +23 -17
- graphgen/templates/__init__.py +2 -2
- graphgen/templates/generation/__init__.py +1 -0
- graphgen/templates/generation/aggregated_generation.py +4 -4
- graphgen/templates/generation/vqa_generation.py +104 -0
- graphgen/templates/kg/__init__.py +3 -0
- graphgen/templates/{kg_extraction.py → kg/kg_extraction.py} +5 -7
- graphgen/templates/{kg_summarization.py → kg/kg_summarization.py} +4 -9
- graphgen/templates/kg/mm_kg_extraction.py +131 -0
- graphgen/utils/__init__.py +1 -1
- graphgen/utils/detect_lang.py +10 -9
- graphgen/utils/hash.py +16 -0
- graphgen/utils/log.py +9 -4
graphgen/bases/base_reader.py
CHANGED

```diff
@@ -1,6 +1,9 @@
+import os
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
+import requests
+
 
 class BaseReader(ABC):
     """
@@ -18,3 +21,45 @@ class BaseReader(ABC):
         :param file_path: Path to the input file.
         :return: List of dictionaries containing the data.
         """
+
+    @staticmethod
+    def filter(data: List[dict]) -> List[dict]:
+        """
+        Filter out entries with empty or missing text in the specified column.
+
+        :param data: List of dictionaries containing the data.
+        :return: Filtered list of dictionaries.
+        """
+
+        def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+            """
+            Check if an image exists at the given local path or URL.
+            :param path_or_url: Local file path or remote URL of the image.
+            :param timeout: Timeout for remote URL requests in seconds.
+            :return: True if the image exists, False otherwise.
+            """
+            if not path_or_url:
+                return False
+            if not path_or_url.startswith(("http://", "https://", "ftp://")):
+                path = path_or_url.replace("file://", "", 1)
+                path = os.path.abspath(path)
+                return os.path.isfile(path)
+            try:
+                resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+                return resp.status_code == 200
+            except requests.RequestException:
+                return False
+
+        filtered_data = []
+        for item in data:
+            if item.get("type") == "text":
+                content = item.get("content", "").strip()
+                if content:
+                    filtered_data.append(item)
+            elif item.get("type") in ("image", "table", "equation"):
+                img_path = item.get("img_path")
+                if _image_exists(img_path):
+                    filtered_data.append(item)
+            else:
+                filtered_data.append(item)
+        return filtered_data
```
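The new `filter` hook is the gate every reader now passes its records through: text entries survive only if they carry non-empty content, while image/table/equation entries survive only if their `img_path` actually resolves. A minimal standalone sketch of that gate (local paths only, with the remote `requests` check stubbed out; the sample records are invented for illustration):

```python
import os

def image_exists(path_or_url: str) -> bool:
    # Local-path branch only; the real filter also HEAD-requests http(s)/ftp URLs.
    if not path_or_url:
        return False
    path = os.path.abspath(path_or_url.replace("file://", "", 1))
    return os.path.isfile(path)

records = [
    {"type": "text", "content": "The Eiffel Tower is 324 meters tall."},
    {"type": "text", "content": "   "},                        # dropped: blank text
    {"type": "image", "img_path": "/tmp/does_not_exist.png"},  # dropped: missing file
    {"type": "metadata", "source": "demo"},                    # kept: unknown types pass through
]

kept = []
for item in records:
    if item.get("type") == "text":
        if item.get("content", "").strip():
            kept.append(item)
    elif item.get("type") in ("image", "table", "equation"):
        if image_exists(item.get("img_path")):
            kept.append(item)
    else:
        kept.append(item)

print([r.get("type") for r in kept])  # ['text', 'metadata']
```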
graphgen/bases/datatypes.py
CHANGED

```diff
@@ -7,8 +7,18 @@ from typing import List, Union
 class Chunk:
     id: str
     content: str
+    type: str
     metadata: dict = field(default_factory=dict)
 
+    @staticmethod
+    def from_dict(key: str, data: dict) -> "Chunk":
+        return Chunk(
+            id=key,
+            content=data.get("content", ""),
+            type=data.get("type", "unknown"),
+            metadata={k: v for k, v in data.items() if k != "content"},
+        )
+
 
 @dataclass
 class QAPair:
```
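`Chunk.from_dict` folds everything except `content` into `metadata`, so type-specific fields like `img_path` or captions ride along without widening the dataclass. A self-contained sketch of the same pattern (the `Chunk` below mirrors the one in `graphgen/bases/datatypes.py` but is redefined here so the snippet runs on its own):

```python
from dataclasses import dataclass, field

@dataclass
class Chunk:
    id: str
    content: str
    type: str
    metadata: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(key: str, data: dict) -> "Chunk":
        # Everything but "content" (including "type") is kept in metadata.
        return Chunk(
            id=key,
            content=data.get("content", ""),
            type=data.get("type", "unknown"),
            metadata={k: v for k, v in data.items() if k != "content"},
        )

raw = {"type": "image", "img_path": "figs/tower.png", "image_caption": ["Eiffel Tower"]}
chunk = Chunk.from_dict("image-abc123", raw)
print(chunk.type)                   # image
print(chunk.metadata["img_path"])   # figs/tower.png
```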
graphgen/configs/vqa_config.yaml
CHANGED

```diff
@@ -1,5 +1,5 @@
 read:
-  input_file: resources/input_examples/
+  input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
 split:
   chunk_size: 1024 # chunk size for text splitting
   chunk_overlap: 100 # chunk overlap for text splitting
@@ -7,16 +7,12 @@ search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
 quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
-  enabled:
-  quiz_samples: 2 # number of quiz samples to generate
-  re_judge: false # whether to re-judge the existing quiz samples
+  enabled: false
 partition: # graph partition configuration
-  method:
+  method: anchor_bfs # partition method
   method_params:
-
-
-    max_tokens_per_community: 10240 # max tokens per community
-    unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
+    anchor_type: image # node type to select anchor nodes
+    max_units_per_community: 10 # atomic partition, one node or edge per community
 generate:
   mode: vqa # atomic, aggregated, multi_hop, cot, vqa
   data_format: ChatML # Alpaca, Sharegpt, ChatML
```
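The config is plain YAML, so the dispatch that consumes it reduces to dictionary lookups. A sketch of how such a config might be loaded and routed (inline YAML and `pyyaml` are assumptions of this example; the field names are copied from the file above):

```python
import yaml  # pip install pyyaml

config_text = """
partition:
  method: anchor_bfs
  method_params:
    anchor_type: image
    max_units_per_community: 10
"""

config = yaml.safe_load(config_text)
partition = config["partition"]
method = partition["method"]
params = partition.get("method_params", {})

if method == "anchor_bfs":
    print("anchor type:", params["anchor_type"])
    print("unit cap:", params["max_units_per_community"])
else:
    raise ValueError(f"Unsupported partition method: {method}")
```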
graphgen/graphgen.py
CHANGED

```diff
@@ -16,7 +16,8 @@ from graphgen.models import (
     Tokenizer,
 )
 from graphgen.operators import (
-    build_kg,
+    build_mm_kg,
+    build_text_kg,
     chunk_documents,
     generate_qas,
     judge_statement,
@@ -25,7 +26,7 @@ from graphgen.operators import (
     read_files,
     search_all,
 )
-from graphgen.utils import async_to_sync_method, compute_content_hash, logger
+from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
 
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
@@ -68,8 +69,8 @@ class GraphGen:
         self.full_docs_storage: JsonKVStorage = JsonKVStorage(
             self.working_dir, namespace="full_docs"
         )
-        self.text_chunks_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="text_chunks"
+        self.chunks_storage: JsonKVStorage = JsonKVStorage(
+            self.working_dir, namespace="chunks"
         )
         self.graph_storage: NetworkXStorage = NetworkXStorage(
             self.working_dir, namespace="graph"
@@ -96,70 +97,122 @@ class GraphGen:
             logger.warning("No data to process")
             return
 
+        assert isinstance(data, list) and isinstance(data[0], dict)
+
         # TODO: configurable whether to use coreference resolution
 
-        assert isinstance(data, list) and isinstance(data[0], dict)
-        new_docs = {
-            compute_content_hash(doc["content"], prefix="doc-"): {
-                "content": doc["content"]
-            }
-            for doc in data
-            if doc.get("type", "text") == "text"
-        }
+        new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
         _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
+        new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
+        new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}
 
-        if len(new_docs) == 0:
-            logger.warning("All docs are already in the storage")
-            return
-        logger.info("[New Docs] inserting %d docs", len(new_docs))
-        _add_entities_and_relations = await build_kg(
-            llm_client=self.synthesizer_llm_client,
-            kg_instance=self.graph_storage,
-            chunks=[
-                Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items()
-            ],
-            progress_bar=self.progress_bar,
-        )
-        if not _add_entities_and_relations:
-            logger.warning("No entities or relations extracted")
-            return
+        await self.full_docs_storage.upsert(new_docs)
+
+        async def _insert_text_docs(text_docs):
+            if len(text_docs) == 0:
+                logger.warning("All text docs are already in the storage")
+                return
+            logger.info("[New Docs] inserting %d text docs", len(text_docs))
+            # Step 2.1: Split chunks and filter existing ones
+            inserting_chunks = await chunk_documents(
+                text_docs,
+                split_config["chunk_size"],
+                split_config["chunk_overlap"],
+                self.tokenizer_instance,
+                self.progress_bar,
+            )
+
+            _add_chunk_keys = await self.chunks_storage.filter_keys(
+                list(inserting_chunks.keys())
+            )
+            inserting_chunks = {
+                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+            }
+
+            if len(inserting_chunks) == 0:
+                logger.warning("All text chunks are already in the storage")
+                return
+
+            logger.info("[New Chunks] inserting %d text chunks", len(inserting_chunks))
+            await self.chunks_storage.upsert(inserting_chunks)
+
+            # Step 2.2: Extract entities and relations from text chunks
+            logger.info("[Text Entity and Relation Extraction] processing ...")
+            _add_entities_and_relations = await build_text_kg(
+                llm_client=self.synthesizer_llm_client,
+                kg_instance=self.graph_storage,
+                chunks=[
+                    Chunk(id=k, content=v["content"], type="text")
+                    for k, v in inserting_chunks.items()
+                ],
+                progress_bar=self.progress_bar,
+            )
+            if not _add_entities_and_relations:
+                logger.warning("No entities or relations extracted from text chunks")
+                return
+
+            await self._insert_done()
+            return _add_entities_and_relations
+
+        async def _insert_multi_modal_docs(mm_docs):
+            if len(mm_docs) == 0:
+                logger.warning("No multi-modal documents to insert")
+                return
+
+            logger.info("[New Docs] inserting %d multi-modal docs", len(mm_docs))
+
+            # Step 3.1: Transform multi-modal documents into chunks and filter existing ones
+            inserting_chunks = await chunk_documents(
+                mm_docs,
+                split_config["chunk_size"],
+                split_config["chunk_overlap"],
+                self.tokenizer_instance,
+                self.progress_bar,
+            )
+
+            _add_chunk_keys = await self.chunks_storage.filter_keys(
+                list(inserting_chunks.keys())
+            )
+            inserting_chunks = {
+                k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
+            }
+
+            if len(inserting_chunks) == 0:
+                logger.warning("All multi-modal chunks are already in the storage")
+                return
+
+            logger.info(
+                "[New Chunks] inserting %d multimodal chunks", len(inserting_chunks)
+            )
+            await self.chunks_storage.upsert(inserting_chunks)
+
+            # Step 3.2: Extract multi-modal entities and relations from chunks
+            logger.info("[Multi-modal Entity and Relation Extraction] processing ...")
+            _add_entities_and_relations = await build_mm_kg(
+                llm_client=self.synthesizer_llm_client,
+                kg_instance=self.graph_storage,
+                chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
+                progress_bar=self.progress_bar,
+            )
+            if not _add_entities_and_relations:
+                logger.warning(
+                    "No entities or relations extracted from multi-modal chunks"
+                )
+                return
+            await self._insert_done()
+            return _add_entities_and_relations
+
+        # Step 2: Insert text documents
+        await _insert_text_docs(new_text_docs)
+        # Step 3: Insert multi-modal documents
+        await _insert_multi_modal_docs(new_mm_docs)
 
     async def _insert_done(self):
         tasks = []
         for storage_instance in [
             self.full_docs_storage,
-            self.text_chunks_storage,
+            self.chunks_storage,
             self.graph_storage,
             self.search_storage,
         ]:
@@ -233,7 +286,10 @@ class GraphGen:
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
         batches = await partition_kg(
-            self.graph_storage, self.tokenizer_instance, partition_config
+            self.graph_storage,
+            self.chunks_storage,
+            self.tokenizer_instance,
+            partition_config,
         )
 
         # Step 2: generate QA pairs
@@ -255,7 +311,7 @@ class GraphGen:
     @async_to_sync_method
     async def clear(self):
        await self.full_docs_storage.drop()
-        await self.text_chunks_storage.drop()
+        await self.chunks_storage.drop()
         await self.search_storage.drop()
         await self.graph_storage.clear()
         await self.rephrase_storage.drop()
```

(Some removed lines in the large hunk did not survive extraction of this page; the removals shown are the ones that could be recovered, and the old storage and import names are reconstructed from the surrounding code.)
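The insert path now keys every document with `compute_mm_hash` before routing it by `type`. The utility's exact behavior is not shown in this diff; the sketch below assumes it hashes a stable serialization of the whole record, which is one reasonable reading, and then demonstrates the text/multi-modal split:

```python
import hashlib
import json

def compute_mm_hash(doc: dict, prefix: str = "") -> str:
    # Assumed behavior: hash a canonical JSON form so records with the same
    # fields always collapse to the same key, regardless of dict order.
    payload = json.dumps(doc, sort_keys=True, ensure_ascii=False)
    return prefix + hashlib.md5(payload.encode("utf-8")).hexdigest()

data = [
    {"type": "text", "content": "The Eiffel Tower is 324 meters tall."},
    {"type": "image", "img_path": "figs/tower.png"},
]

new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
new_text_docs = {k: v for k, v in new_docs.items() if v.get("type") == "text"}
new_mm_docs = {k: v for k, v in new_docs.items() if v.get("type") != "text"}

print(len(new_text_docs), "text /", len(new_mm_docs), "multi-modal")
```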
graphgen/models/__init__.py
CHANGED

```diff
@@ -6,10 +6,11 @@ from .generator import (
     MultiHopGenerator,
     VQAGenerator,
 )
-from .kg_builder import LightRAGKGBuilder
+from .kg_builder import LightRAGKGBuilder, MMKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
 from .partitioner import (
+    AnchorBFSPartitioner,
     BFSPartitioner,
     DFSPartitioner,
     ECEPartitioner,
```
graphgen/models/generator/aggregated_generator.py
CHANGED

```diff
@@ -53,7 +53,7 @@ class AggregatedGenerator(BaseGenerator):
         #     ]
         # )
         prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
-            entities=entities_str, relations=relations_str
+            entities=entities_str, relationships=relations_str
         )
         return prompt
 
@@ -115,8 +115,8 @@ class AggregatedGenerator(BaseGenerator):
         question_generation_prompt = self._build_prompt_for_question_generation(context)
         response = await self.llm_client.generate_answer(question_generation_prompt)
         question = self.parse_response(response)["question"]
-        logger.info("Question: %s", question)
-        logger.info("Answer: %s", context)
+        logger.debug("Question: %s", question)
+        logger.debug("Answer: %s", context)
         qa_pairs = {
             compute_content_hash(question): {
                 "question": question,
```

(The removed keyword-argument line and log calls lost their changed fragments during extraction; they are reconstructed here from the added lines, which demote the logs from info to debug.)
graphgen/models/generator/atomic_generator.py
CHANGED

```diff
@@ -42,8 +42,8 @@ class AtomicGenerator(BaseGenerator):
             return {}
         question = question.strip('"')
         answer = answer.strip('"')
-        logger.info("Question: %s", question)
-        logger.info("Answer: %s", answer)
+        logger.debug("Question: %s", question)
+        logger.debug("Answer: %s", answer)
         return {
             compute_content_hash(question): {
                 "question": question,
```
graphgen/models/generator/cot_generator.py
CHANGED

```diff
@@ -85,8 +85,8 @@ class CoTGenerator(BaseGenerator):
 
         question = question.strip('"')
         reasoning_path = reasoning_path.strip('"')
-        logger.info("CoT Question: %s", question)
-        logger.info("CoT Reasoning Path: %s", reasoning_path)
+        logger.debug("CoT Question: %s", question)
+        logger.debug("CoT Reasoning Path: %s", reasoning_path)
         return {
             "question": question,
             "reasoning_path": reasoning_path,
@@ -110,7 +110,7 @@ class CoTGenerator(BaseGenerator):
         question, reasoning_path = response["question"], response["reasoning_path"]
         prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
         cot_answer = await self.llm_client.generate_answer(prompt)
-        logger.info("CoT Answer: %s", cot_answer)
+        logger.debug("CoT Answer: %s", cot_answer)
         qa_pairs = {
             compute_content_hash(question): {
                 "question": question,
```
graphgen/models/generator/multi_hop_generator.py
CHANGED

```diff
@@ -45,8 +45,8 @@ class MultiHopGenerator(BaseGenerator):
             return {}
         question = question.strip('"')
         answer = answer.strip('"')
-        logger.info("Question: %s", question)
-        logger.info("Answer: %s", answer)
+        logger.debug("Question: %s", question)
+        logger.debug("Answer: %s", answer)
         return {
             compute_content_hash(question): {
                 "question": question,
```
graphgen/models/generator/vqa_generator.py
CHANGED

```diff
@@ -2,6 +2,8 @@ from dataclasses import dataclass
 from typing import Any
 
 from graphgen.bases import BaseGenerator
+from graphgen.templates import VQA_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
 
 
 @dataclass
@@ -10,14 +12,127 @@ class VQAGenerator(BaseGenerator):
     def build_prompt(
         batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
     ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
         )
 
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = VQA_GENERATION_PROMPT[language].format(
+            entities=entities_str, relationships=relationships_str
+        )
+        return prompt
+
     @staticmethod
     def parse_response(response: str) -> Any:
+        """
+        Parse the LLM response and return the generated QAs
+        :param response
+        :return: QA pairs
+        """
+        qa_pairs = {}
+        qa_list = response.strip().split("\n\n")
+        for qa in qa_list:
+            if "Question:" in qa and "Answer:" in qa:
+                question = qa.split("Question:")[1].split("Answer:")[0].strip()
+                answer = qa.split("Answer:")[1].strip()
+            elif "问题:" in qa and "答案:" in qa:
+                question = qa.split("问题:")[1].split("答案:")[0].strip()
+                answer = qa.split("答案:")[1].strip()
+            else:
+                logger.error("Failed to parse QA pair: %s", qa)
+                continue
+            question = question.strip('"')
+            answer = answer.strip('"')
+            logger.debug("Question: %s", question)
+            logger.debug("Answer: %s", answer)
+            qa_pairs[compute_content_hash(question)] = {
+                "question": question,
+                "answer": answer,
+            }
+        return qa_pairs
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
+        nodes, _ = batch
+        for node in nodes:
+            node_data = node[1]
+            if "images" in node_data and node_data["images"]:
+                img_path = node_data["images"]["img_path"]
+                for qa in qa_pairs.values():
+                    qa["img_path"] = img_path
+        result.update(qa_pairs)
+        return result
+
+    @staticmethod
+    def format_generation_results(
+        results: list[dict], output_data_format: str
+    ) -> list[dict[str, Any]]:
+        if output_data_format == "Alpaca":
+            results = [
+                {
+                    "instruction": v["question"],
+                    "input": "",
+                    "output": v["answer"],
+                    "image": v.get("img_path", ""),
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "Sharegpt":
+            results = [
+                {
+                    "conversations": [
+                        {
+                            "from": "human",
+                            "value": [
+                                {"text": v["question"], "image": v.get("img_path", "")}
+                            ],
+                        },
+                        {"from": "gpt", "value": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "ChatML":
+            results = [
+                {
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"text": v["question"], "image": v.get("img_path", "")}
+                            ],
+                        },
+                        {"role": "assistant", "content": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        else:
+            raise ValueError(f"Unknown output data format: {output_data_format}")
+        return results
```
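`parse_response` expects QA pairs separated by blank lines, labeled either in English (`Question:` / `Answer:`) or Chinese (`问题:` / `答案:`). A standalone rerun of that parsing logic on a fabricated response, with the content hash replaced by a plain `hashlib` stand-in so nothing is imported from graphgen:

```python
import hashlib

def content_hash(text: str) -> str:
    # Stand-in for graphgen.utils.compute_content_hash.
    return hashlib.md5(text.encode("utf-8")).hexdigest()

response = (
    "Question: What landmark is shown in the image?\n"
    "Answer: The Eiffel Tower.\n"
    "\n"
    "Question: How tall is it?\n"
    "Answer: 324 meters."
)

qa_pairs = {}
for qa in response.strip().split("\n\n"):
    if "Question:" in qa and "Answer:" in qa:
        question = qa.split("Question:")[1].split("Answer:")[0].strip()
        answer = qa.split("Answer:")[1].strip()
        qa_pairs[content_hash(question)] = {"question": question, "answer": answer}

print(len(qa_pairs), "pairs parsed")  # 2 pairs parsed
```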
graphgen/models/kg_builder/__init__.py
CHANGED

```diff
@@ -1 +1,2 @@
 from .light_rag_kg_builder import LightRAGKGBuilder
+from .mm_kg_builder import MMKGBuilder
```
graphgen/models/kg_builder/light_rag_kg_builder.py
CHANGED

```diff
@@ -6,7 +6,6 @@ from typing import Dict, List, Tuple
 from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
 from graphgen.templates import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT
 from graphgen.utils import (
-    detect_if_chinese,
     detect_main_language,
     handle_single_entity_extraction,
     handle_single_relationship_extraction,
@@ -33,8 +32,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
         content = chunk.content
 
         # step 1: language_detection
-        language = "Chinese" if detect_if_chinese(content) else "English"
-        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+        language = detect_main_language(content)
 
         hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
             **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
@@ -42,7 +40,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
 
         # step 2: initial glean
         final_result = await self.llm_client.generate_answer(hint_prompt)
-        logger.info("First extraction result: %s", final_result)
+        logger.debug("First extraction result: %s", final_result)
 
         # step3: iterative refinement
         history = pack_history_conversations(hint_prompt, final_result)
@@ -57,7 +55,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
             glean_result = await self.llm_client.generate_answer(
                 text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
             )
-            logger.info("Loop %s glean: %s", loop_idx + 1, glean_result)
+            logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
 
             history += pack_history_conversations(
                 KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
@@ -201,11 +199,6 @@ class LightRAGKGBuilder(BaseKGBuilder):
 
         tokenizer_instance = self.llm_client.tokenizer
         language = detect_main_language(description)
-        if language == "en":
-            language = "English"
-        else:
-            language = "Chinese"
-        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
 
         tokens = tokenizer_instance.encode(description)
         if len(tokens) < max_summary_tokens:
```

(The old language-detection line lost its tail during extraction; it is reconstructed here from the removed `detect_if_chinese` import.)
graphgen/models/kg_builder/mm_kg_builder.py
ADDED

```diff
@@ -0,0 +1,93 @@
+import re
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+from graphgen.bases import BaseLLMClient, Chunk
+from graphgen.templates import MMKG_EXTRACTION_PROMPT
+from graphgen.utils import (
+    detect_main_language,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    split_string_by_multi_markers,
+)
+
+from .light_rag_kg_builder import LightRAGKGBuilder
+
+
+class MMKGBuilder(LightRAGKGBuilder):
+    llm_client: BaseLLMClient = None
+
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """
+        Extract entities and relationships from a single multi-modal chunk using the LLM client.
+        Expect to get a mini graph which contains a central multi-modal entity
+        and its related text entities and relationships.
+        Like:
+        (image: "image_of_eiffel_tower") --[located_in]--> (text: "Paris")
+        (image: "image_of_eiffel_tower") --[built_in]--> (text: "1889")
+        (text: "Eiffel Tower") --[height]--> (text: "324 meters")
+        :param chunk
+        """
+        chunk_id = chunk.id
+        chunk_type = chunk.type  # image | table | formula | ...
+        metadata = chunk.metadata
+
+        # choose different extraction strategies based on chunk type
+        if chunk_type == "image":
+            image_caption = "\n".join(metadata.get("image_caption", ""))
+            language = detect_main_language(image_caption)
+            prompt_template = MMKG_EXTRACTION_PROMPT[language].format(
+                **MMKG_EXTRACTION_PROMPT["FORMAT"],
+                chunk_type=chunk_type,
+                chunk_id=chunk_id,
+                chunk_text=image_caption,
+            )
+            result = await self.llm_client.generate_answer(prompt_template)
+            logger.debug("Image chunk extraction result: %s", result)
+
+            # parse the result
+            records = split_string_by_multi_markers(
+                result,
+                [
+                    MMKG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                    MMKG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+                ],
+            )
+
+            nodes = defaultdict(list)
+            edges = defaultdict(list)
+
+            for record in records:
+                match = re.search(r"\((.*)\)", record)
+                if not match:
+                    continue
+                inner = match.group(1)
+
+                attributes = split_string_by_multi_markers(
+                    inner, [MMKG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+                )
+
+                entity = await handle_single_entity_extraction(attributes, chunk_id)
+                if entity is not None:
+                    nodes[entity["entity_name"]].append(entity)
+                    continue
+
+                relation = await handle_single_relationship_extraction(
+                    attributes, chunk_id
+                )
+                if relation is not None:
+                    key = (relation["src_id"], relation["tgt_id"])
+                    edges[key].append(relation)
+
+            return dict(nodes), dict(edges)
+
+        if chunk_type == "table":
+            pass  # TODO: implement table-based entity and relationship extraction
+        if chunk_type == "formula":
+            pass  # TODO: implement formula-based entity and relationship extraction
+
+        logger.error("Unsupported chunk type for MMKGBuilder: %s", chunk_type)
+        return defaultdict(list), defaultdict(list)
```
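The extraction loop assumes the LLM emits parenthesized, delimiter-separated records. The actual delimiter strings live in `MMKG_EXTRACTION_PROMPT["FORMAT"]` and are not shown in this diff; the sketch below picks LightRAG-style placeholders (`<|>`, `##`, `<|COMPLETE|>`) purely for illustration of the parsing mechanics:

```python
import re

TUPLE_DELIM, RECORD_DELIM, DONE_DELIM = "<|>", "##", "<|COMPLETE|>"

raw = (
    '("entity"<|>image_of_eiffel_tower<|>image<|>A photo of the Eiffel Tower")##'
    '("relationship"<|>image_of_eiffel_tower<|>Paris<|>located_in")<|COMPLETE|>'
)

# Split on both the record and completion markers, as the builder does.
pattern = re.escape(RECORD_DELIM) + "|" + re.escape(DONE_DELIM)
records = [r for r in re.split(pattern, raw) if r.strip()]

for record in records:
    match = re.search(r"\((.*)\)", record)
    if not match:
        continue
    attributes = match.group(1).split(TUPLE_DELIM)
    kind = attributes[0].strip('"')
    print(kind, "->", attributes[1:])
```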
graphgen/models/partitioner/__init__.py
CHANGED

```diff
@@ -1,3 +1,4 @@
+from .anchor_bfs_partitioner import AnchorBFSPartitioner
 from .bfs_partitioner import BFSPartitioner
 from .dfs_partitioner import DFSPartitioner
 from .ece_partitioner import ECEPartitioner
```
graphgen/models/partitioner/anchor_bfs_partitioner.py
ADDED

```diff
@@ -0,0 +1,128 @@
+import random
+from collections import deque
+from typing import Any, List, Literal, Set, Tuple
+
+from graphgen.bases import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+
+from .bfs_partitioner import BFSPartitioner
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+class AnchorBFSPartitioner(BFSPartitioner):
+    """
+    Anchor BFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a node of a specified type as the anchor.
+    2. Expand the community using BFS until the max unit size is reached. (A unit is a node or an edge.)
+    3. Non-anchor units can only be "pulled" into a community and never become seeds themselves.
+    For example, for VQA tasks, we may want to use image nodes as anchors and expand to nearby text nodes and edges.
+    """
+
+    def __init__(
+        self,
+        *,
+        anchor_type: Literal["image"] = "image",
+        anchor_ids: Set[str] | None = None,
+    ) -> None:
+        super().__init__()
+        self.anchor_type = anchor_type
+        self.anchor_ids = anchor_ids
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()  # List[tuple[id, meta]]
+        edges = await g.get_all_edges()  # List[tuple[u, v, meta]]
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        anchors: Set[str] = await self._pick_anchor_ids(nodes)
+        if not anchors:
+            return []  # if no anchors, return empty list
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        seeds = list(anchors)
+        random.shuffle(seeds)
+
+        for seed_node in seeds:
+            if seed_node in used_n:
+                continue
+            comm_n, comm_e = await self._grow_community(
+                seed_node, adj, max_units_per_community, used_n, used_e
+            )
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
+
+    async def _pick_anchor_ids(
+        self,
+        nodes: List[tuple[str, dict]],
+    ) -> Set[str]:
+        if self.anchor_ids is not None:
+            return self.anchor_ids
+
+        anchor_ids: Set[str] = set()
+        for node_id, meta in nodes:
+            node_type = str(meta.get("entity_type", "")).lower()
+            if self.anchor_type.lower() in node_type:
+                anchor_ids.add(node_id)
+        return anchor_ids
+
+    @staticmethod
+    async def _grow_community(
+        seed: str,
+        adj: dict[str, List[str]],
+        max_units: int,
+        used_n: set[str],
+        used_e: set[frozenset[str]],
+    ) -> Tuple[List[str], List[Tuple[str, str]]]:
+        """
+        Grow a community from the seed node using BFS.
+        :param seed: seed node id
+        :param adj: adjacency list
+        :param max_units: maximum number of units (nodes + edges) in the community
+        :param used_n: set of used node ids
+        :param used_e: set of used edge keys
+        :return: (list of node ids, list of edge tuples)
+        """
+        comm_n: List[str] = []
+        comm_e: List[Tuple[str, str]] = []
+        queue: deque[tuple[str, Any]] = deque([(NODE_UNIT, seed)])
+        cnt = 0
+
+        while queue and cnt < max_units:
+            k, it = queue.popleft()
+
+            if k == NODE_UNIT:
+                if it in used_n:
+                    continue
+                used_n.add(it)
+                comm_n.append(it)
+                cnt += 1
+                for nei in adj[it]:
+                    e_key = frozenset((it, nei))
+                    if e_key not in used_e:
+                        queue.append((EDGE_UNIT, e_key))
+            else:  # EDGE_UNIT
+                if it in used_e:
+                    continue
+                used_e.add(it)
+                u, v = it
+                comm_e.append((u, v))
+                cnt += 1
+                for n in it:
+                    if n not in used_n:
+                        queue.append((NODE_UNIT, n))
+
+        return comm_n, comm_e
```
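Because `_grow_community` is a plain BFS with a unit budget, its behavior is easy to check on a toy graph. The sketch below inlines a simplified version with no storage layer, anchors picked by a `type` field on the nodes; the graph and its entity types are invented for the demo:

```python
from collections import deque

nodes = {
    "img_tower": {"entity_type": "image"},
    "paris": {"entity_type": "location"},
    "eiffel": {"entity_type": "landmark"},
}
adj = {
    "img_tower": ["paris", "eiffel"],
    "paris": ["img_tower", "eiffel"],
    "eiffel": ["img_tower", "paris"],
}

def grow(seed: str, max_units: int):
    comm_n, comm_e = [], []
    used_n, used_e = set(), set()
    queue = deque([("n", seed)])
    cnt = 0
    while queue and cnt < max_units:
        kind, it = queue.popleft()
        if kind == "n":
            if it in used_n:
                continue
            used_n.add(it)
            comm_n.append(it)
            cnt += 1  # a node counts as one unit
            for nei in adj[it]:
                key = frozenset((it, nei))
                if key not in used_e:
                    queue.append(("e", key))
        else:
            if it in used_e:
                continue
            used_e.add(it)
            comm_e.append(tuple(it))
            cnt += 1  # an edge counts as one unit too
            for n in it:
                if n not in used_n:
                    queue.append(("n", n))
    return comm_n, comm_e

anchors = [nid for nid, meta in nodes.items() if "image" in meta["entity_type"]]
print(grow(anchors[0], max_units=4))
```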
graphgen/models/reader/csv_reader.py
CHANGED

```diff
@@ -9,6 +9,9 @@ class CSVReader(BaseReader):
     def read(self, file_path: str) -> List[Dict[str, Any]]:
 
         df = pd.read_csv(file_path)
-        return df.to_dict(orient="records")
+        for _, row in df.iterrows():
+            if "type" in row and row["type"] == "text" and self.text_column not in row:
+                raise ValueError(
+                    f"Missing '{self.text_column}' in document: {row.to_dict()}"
+                )
+        return self.filter(df.to_dict(orient="records"))
```

(Two of the removed validation lines did not survive extraction; only the recoverable return line is shown.)
graphgen/models/reader/json_reader.py
CHANGED

```diff
@@ -10,9 +10,9 @@ class JSONReader(BaseReader):
             data = json.load(f)
             if isinstance(data, list):
                 for doc in data:
-                    if self.text_column not in doc:
+                    if doc.get("type") == "text" and self.text_column not in doc:
                         raise ValueError(
                             f"Missing '{self.text_column}' in document: {doc}"
                         )
-                return data
+                return self.filter(data)
             raise ValueError("JSON file must contain a list of documents.")
```
graphgen/models/reader/jsonl_reader.py
CHANGED

```diff
@@ -12,12 +12,11 @@ class JSONLReader(BaseReader):
         for line in f:
             try:
                 doc = json.loads(line)
-                if self.text_column in doc:
-                    docs.append(doc)
-                else:
+                if doc.get("type") == "text" and self.text_column not in doc:
                     raise ValueError(
                         f"Missing '{self.text_column}' in document: {doc}"
                     )
+                docs.append(doc)
             except json.JSONDecodeError as e:
                 logger.error("Error decoding JSON line: %s. Error: %s", line, e)
-        return docs
+        return self.filter(docs)
```
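All readers now share the same contract: validate the text column only for text-typed rows, keep everything else, then hand the batch to `BaseReader.filter`. A compact demonstration of the JSONL branch on an in-memory stream (assuming the default text column is `content`; the filter step is covered earlier):

```python
import io
import json

text_column = "content"
stream = io.StringIO(
    '{"type": "text", "content": "The Eiffel Tower is in Paris."}\n'
    '{"type": "image", "img_path": "figs/tower.png"}\n'
    "not json at all\n"
)

docs = []
for line in stream:
    try:
        doc = json.loads(line)
        # Only text-typed rows must carry the text column; others pass through.
        if doc.get("type") == "text" and text_column not in doc:
            raise ValueError(f"Missing '{text_column}' in document: {doc}")
        docs.append(doc)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON line: {line!r}. Error: {e}")

print(len(docs), "documents accepted")  # 2 documents accepted
```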
graphgen/models/reader/pdf_reader.py
CHANGED

```diff
@@ -74,7 +74,7 @@ class PDFReader(BaseReader):
         kwargs = {**self._default_kwargs, **override}
 
         mineru_result = self._call_mineru(pdf_path, kwargs)
-        return mineru_result
+        return self.filter(mineru_result)
 
     def _call_mineru(
         self, pdf_path: Path, kwargs: Dict[str, Any]
@@ -172,8 +172,6 @@ class MinerUParser:
             for key in ("page_idx", "bbox", "text_level"):
                 if item.get(key) is not None:
                     del item[key]
-            if item["type"] == "text" and not item["content"].strip():
-                continue
             results.append(item)
         return results
```
graphgen/models/reader/txt_reader.py
CHANGED

```diff
@@ -11,4 +11,4 @@ class TXTReader(BaseReader):
             line = line.strip()
             if line:
                 docs.append({self.text_column: line})
-        return docs
+        return self.filter(docs)
```
graphgen/operators/__init__.py
CHANGED

```diff
@@ -1,4 +1,4 @@
-from .build_kg import build_kg
+from .build_kg import build_mm_kg, build_text_kg
 from .generate import generate_qas
 from .judge import judge_statement
 from .partition import partition_kg
```
graphgen/operators/build_kg/__init__.py
CHANGED

```diff
@@ -1 +1,2 @@
-from .build_kg import build_kg
+from .build_mm_kg import build_mm_kg
+from .build_text_kg import build_text_kg
```
graphgen/operators/build_kg/build_mm_kg.py
ADDED

```diff
@@ -0,0 +1,56 @@
+from collections import defaultdict
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import MMKGBuilder, OpenAIClient
+from graphgen.utils import run_concurrent
+
+
+async def build_mm_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    Build multi-modal KG and merge into kg_instance
+    :param llm_client: Synthesizer LLM model to extract entities and relationships
+    :param kg_instance
+    :param chunks
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return:
+    """
+    mm_builder = MMKGBuilder(llm_client=llm_client)
+
+    results = await run_concurrent(
+        mm_builder.extract,
+        chunks,
+        desc="[2/4] Extracting entities and relationships from multi-modal chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await run_concurrent(
+        lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance),
+        list(nodes.items()),
+        desc="Inserting entities into storage",
+    )
+
+    await run_concurrent(
+        lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance),
+        list(edges.items()),
+        desc="Inserting relationships into storage",
+    )
+
+    return kg_instance
```
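After the concurrent extraction, the per-chunk node and edge dicts are merged, with undirected edge keys normalized via `tuple(sorted(k))` so that the two directions of the same relation collapse into one entry. A standalone rehearsal of that merge on fabricated extraction results:

```python
from collections import defaultdict

# Two per-chunk extraction results: (nodes, edges), as MMKGBuilder.extract returns.
results = [
    ({"eiffel tower": [{"description": "landmark"}]},
     {("eiffel tower", "paris"): [{"description": "located in"}]}),
    ({"paris": [{"description": "city"}]},
     {("paris", "eiffel tower"): [{"description": "contains"}]}),
]

nodes = defaultdict(list)
edges = defaultdict(list)
for n, e in results:
    for k, v in n.items():
        nodes[k].extend(v)
    for k, v in e.items():
        # Both edge directions collapse to one sorted key.
        edges[tuple(sorted(k))].extend(v)

print(list(edges))         # [('eiffel tower', 'paris')]
print(len(edges[("eiffel tower", "paris")]))  # 2 merged descriptions
```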
graphgen/operators/build_kg/{build_kg.py → build_text_kg.py}
RENAMED

```diff
@@ -9,7 +9,7 @@ from graphgen.models import LightRAGKGBuilder, OpenAIClient
 from graphgen.utils import run_concurrent
 
 
-async def build_kg(
+async def build_text_kg(
     llm_client: OpenAIClient,
     kg_instance: BaseGraphStorage,
     chunks: List[Chunk],
```
graphgen/operators/judge.py
CHANGED

```diff
@@ -37,7 +37,7 @@ async def judge_statement(  # pylint: disable=too-many-statements
         edge_data = edge[2]
 
         if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
-            logger.info(
+            logger.debug(
                 "Edge %s -> %s already judged, loss: %s, skip",
                 source_id,
                 target_id,
@@ -63,7 +63,7 @@ async def judge_statement(  # pylint: disable=too-many-statements
 
             loss = yes_no_loss_entropy(judgements, gts)
 
-            logger.info(
+            logger.debug(
                 "Edge %s -> %s description: %s loss: %s",
                 source_id,
                 target_id,
@@ -100,7 +100,7 @@ async def judge_statement(  # pylint: disable=too-many-statements
         node_data = node[1]
 
         if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
-            logger.info(
+            logger.debug(
                 "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
             )
             return node_id, node_data
@@ -123,14 +123,14 @@ async def judge_statement(  # pylint: disable=too-many-statements
 
             loss = yes_no_loss_entropy(judgements, gts)
 
-            logger.info(
+            logger.debug(
                 "Node %s description: %s loss: %s", node_id, description, loss
             )
 
             node_data["loss"] = loss
         except Exception as e:  # pylint: disable=broad-except
             logger.error("Error in judging entity %s: %s", node_id, e)
-            logger.info("Use default loss 0.1")
+            logger.error("Use default loss 0.1")
             node_data["loss"] = -math.log(0.1)
 
         await graph_storage.update_node(node_id, node_data)
```
graphgen/operators/partition/partition_kg.py
CHANGED

```diff
@@ -1,7 +1,8 @@
 from typing import Any
 
-from graphgen.bases import BaseGraphStorage, BaseTokenizer
+from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer
 from graphgen.models import (
+    AnchorBFSPartitioner,
     BFSPartitioner,
     DFSPartitioner,
     ECEPartitioner,
@@ -14,6 +15,7 @@ from .pre_tokenize import pre_tokenize
 
 async def partition_kg(
     kg_instance: BaseGraphStorage,
+    chunk_storage: BaseKVStorage,
     tokenizer: Any = BaseTokenizer,
     partition_config: dict = None,
 ) -> list[
@@ -39,10 +41,28 @@ async def partition_kg(
     elif method == "leiden":
         logger.info("Partitioning knowledge graph using Leiden method.")
         partitioner = LeidenPartitioner()
+    elif method == "anchor_bfs":
+        logger.info("Partitioning knowledge graph using Anchor BFS method.")
+        partitioner = AnchorBFSPartitioner(
+            anchor_type=method_params.get("anchor_type"),
+            anchor_ids=set(method_params.get("anchor_ids", []))
+            if method_params.get("anchor_ids")
+            else None,
+        )
     else:
         raise ValueError(f"Unsupported partition method: {method}")
 
     communities = await partitioner.partition(g=kg_instance, **method_params)
     logger.info("Partitioned the graph into %d communities.", len(communities))
     batches = await partitioner.community2batch(communities, g=kg_instance)
+
+    for _, batch in enumerate(batches):
+        nodes, edges = batch
+        for node_id, node_data in nodes:
+            entity_type = node_data.get("entity_type")
+            if entity_type and "image" in entity_type.lower():
+                node_id = node_id.strip('"').lower()
+                image_data = await chunk_storage.get_by_id(node_id)
+                if image_data:
+                    node_data["images"] = image_data
     return batches
```
graphgen/operators/split/split_chunks.py
CHANGED
@@ -48,25 +48,31 @@ async def chunk_documents(
     async for doc_key, doc in tqdm_async(
         new_docs.items(), desc="[1/4]Chunking documents", unit="doc"
     ):
-        doc_language = detect_main_language(doc["content"])
-        text_chunks = split_chunks(
-            doc["content"],
-            language=doc_language,
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-
-        chunks = {
-            compute_content_hash(txt, prefix="chunk-"): {
-                "content": txt,
-                "full_doc_id": doc_key,
-                "length": len(tokenizer_instance.encode(txt)),
-                "language": doc_language,
-            }
-            for txt in text_chunks
-        }
-
-
+        doc_type = doc.get("type")
+        if doc_type == "text":
+            doc_language = detect_main_language(doc["content"])
+            text_chunks = split_chunks(
+                doc["content"],
+                language=doc_language,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap,
+            )
+
+            chunks = {
+                compute_content_hash(txt, prefix="chunk-"): {
+                    "content": txt,
+                    "type": "text",
+                    "full_doc_id": doc_key,
+                    "length": len(tokenizer_instance.encode(txt))
+                    if tokenizer_instance
+                    else len(txt),
+                    "language": doc_language,
+                }
+                for txt in text_chunks
+            }
+        else:
+            chunks = {doc_key.replace("doc-", f"{doc_type}-"): {**doc}}
+
         inserting_chunks.update(chunks)
 
         if progress_bar is not None:
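The new `else` branch keys non-text chunks by rewriting the document-id prefix; a worked example with a hypothetical id:

```python
# Hypothetical doc record; the "doc-" prefix follows the doc_key convention above.
doc_key = "doc-c71ef797e99af81047fbc7509609c765"
doc = {"type": "image", "img_path": "figs/eiffel.png"}

chunk_key = doc_key.replace("doc-", f"{doc['type']}-")
assert chunk_key == "image-c71ef797e99af81047fbc7509609c765"
```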
graphgen/templates/__init__.py
CHANGED
@@ -5,9 +5,9 @@ from .generation import (
     ATOMIC_GENERATION_PROMPT,
     COT_GENERATION_PROMPT,
     MULTI_HOP_GENERATION_PROMPT,
+    VQA_GENERATION_PROMPT,
 )
-from .kg_extraction import KG_EXTRACTION_PROMPT
-from .kg_summarization import KG_SUMMARIZATION_PROMPT
+from .kg import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT, MMKG_EXTRACTION_PROMPT
 from .question_generation import QUESTION_GENERATION_PROMPT
 from .search_judgement import SEARCH_JUDGEMENT_PROMPT
 from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
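With these re-exports in place, the new prompts resolve from the package root (a two-line sanity check that follows directly from the updated `__init__.py`):

```python
from graphgen.templates import MMKG_EXTRACTION_PROMPT, VQA_GENERATION_PROMPT
```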
graphgen/templates/generation/__init__.py
CHANGED
@@ -2,3 +2,4 @@ from .aggregated_generation import AGGREGATED_GENERATION_PROMPT
 from .atomic_generation import ATOMIC_GENERATION_PROMPT
 from .cot_generation import COT_GENERATION_PROMPT
 from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
+from .vqa_generation import VQA_GENERATION_PROMPT
graphgen/templates/generation/aggregated_generation.py
CHANGED
@@ -1,7 +1,7 @@
 # pylint: disable=C0301
 ANSWER_REPHRASING_CONTEXT_EN: str = """---Role---
 You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below. You may refer to the original text to assist in generating the rephrased version, but ensure that the final output text meets the requirements.
-Use {language} as output language.
+Use English as output language.
 
 ---Goal---
 To generate a version of the text that is rephrased and conveys the same meaning as the original entity and relationship descriptions, while:
@@ -52,7 +52,7 @@ To generate a version of the text that is rephrased and conveys the same meaning
 
 ANSWER_REPHRASING_CONTEXT_ZH: str = """---角色---
 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。你可以参考原始文本辅助生成,但需要确保最终输出的文本符合要求。
-使用{language}作为输出语言。
+使用中文作为输出语言。
 
 ---目标---
 生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
@@ -100,7 +100,7 @@ ANSWER_REPHRASING_CONTEXT_ZH: str = """---角色---
 
 ANSWER_REPHRASING_EN: str = """---Role---
 You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below.
-Use {language} as output language.
+Use English as output language.
 
 ---Goal---
 To generate a version of the text that is rephrased and conveys the same meaning as the original entity and relationship descriptions, while:
@@ -146,7 +146,7 @@ To generate a version of the text that is rephrased and conveys the same meaning
 
 ANSWER_REPHRASING_ZH: str = """---角色---
 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。
-使用{language}作为输出语言。
+使用中文作为输出语言。
 
 ---目标---
 生成文本的重述版本,使其传达与原始实体和关系描述相同的含义,同时:
graphgen/templates/generation/vqa_generation.py
ADDED
@@ -0,0 +1,104 @@
+# pylint: disable=C0301
+TEMPLATE_EN: str = """You are a senior VQA data engineer. Your task is to generate logically coherent, verifiable and non-hallucinated question-answer pairs for the given multi-modal samples.
+Use English as the output language.
+
+---Objectives---
+Create multiple sets of VQA question-answer pairs that satisfy the following:
+1. Only ask about objectively existing facts in the given data, avoiding subjective or ambiguous questions.
+2. Ensure that each question has a clear and verifiable answer, avoiding questions with no answer or uncertainty.
+3. Questions should cover various aspects of both image and text content, ensuring diversity and comprehensiveness.
+4. Avoid repetitive questions, ensuring that each question is unique and meaningful.
+5. Use clear and concise language, avoiding complex or ambiguous wording.
+
+---Instructions---
+1. Carefully analyze the provided entities and relationships to identify:
+- Key concepts and their hierarchical relationships
+- Temporal sequences and time order
+- Cause-and-effect relationships
+- Dependencies between different elements
+2. Organize the information into a logical sequence by:
+- Starting with foundational concepts
+- Gradually building up to more complex relationships
+- Grouping related ideas together
+- Creating clear transitions between sections
+3. Maintain the following when generating question-answer pairs:
+- Logical flow
+- Clear connections between concepts
+- Appropriate context and background
+- Coherent narrative structure
+4. Review and refine the question-answer pairs to ensure:
+- Overall logical consistency
+- Clear cause-and-effect relationships
+
+################
+-Entities-
+################
+{entities}
+################
+-Relationships-
+################
+{relationships}
+################
+Directly output the generated questions and answers, please do not directly copy the example questions and answers, and do not provide irrelevant information.
+Here is the response format you should follow:
+Question: <Question1>
+Answer: <Answer1>
+
+Question: <Question2>
+Answer: <Answer2>
+
+"""
+
+TEMPLATE_ZH: str = """---角色---
+你是一位资深 VQA 数据工程师。你需要为给定的多模态样本生成逻辑连贯、可验证、无幻觉的问答对。
+使用中文作为输出语言。
+
+---目标---
+创建多组 VQA 问答对,满足:
+1. 仅询问给定数据中客观存在的事实,避免主观或模糊的问题。
+2. 确保每个问题都有明确且可验证的答案,避免无答案或不确定的问题。
+3. 问题应涵盖图像和文本内容的各个方面,确保多样性和全面性。
+4. 避免重复问题,确保每个问题都是独特且有意义的。
+5. 使用清晰简洁的语言,避免复杂或含糊的措辞。
+
+---说明---
+1. 仔细分析提供的实体和关系,以识别:
+- 关键概念及其层级关系
+- 时间序列和时间顺序
+- 因果关系
+- 不同元素之间的依赖关系
+2. 通过以下方式将信息组织成逻辑顺序:
+- 从基础概念开始
+- 逐步建立更复杂的关系
+- 将相关的想法分组在一起
+- 在各部分之间创建清晰的过渡
+3. 生成问答对时保持:
+- 逻辑流畅
+- 概念之间的清晰联系
+- 适当的上下文和背景
+- 连贯的叙述结构
+4. 检查和完善问答对以确保:
+- 整体逻辑一致性
+- 清晰的因果关系
+
+################
+-实体-
+################
+{entities}
+
+################
+-关系-
+################
+{relationships}
+################
+直接输出生成的问题和答案,请不要直接复制示例问题和答案,不要输出无关内容。
+以下是你应该遵循的响应格式:
+问题: <问题1>
+答案: <答案1>
+
+问题: <问题2>
+答案: <答案2>
+
+"""
+
+VQA_GENERATION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH}
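`TEMPLATE_EN` carries exactly two placeholders, `{entities}` and `{relationships}`, so plain `str.format` suffices; a sketch with made-up record payloads:

```python
from graphgen.templates import VQA_GENERATION_PROMPT

# Payloads are illustrative only; the record syntax mimics the KG extraction format.
prompt = VQA_GENERATION_PROMPT["en"].format(
    entities='("entity"<|>"Eiffel Tower"<|>"landmark"<|>"Iconic Paris structure")',
    relationships='("relationship"<|>"image-c71ef797e99af81047fbc7509609c765"<|>"Eiffel Tower"<|>"The image shows the tower")',
)
```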
graphgen/templates/kg/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .kg_extraction import KG_EXTRACTION_PROMPT
+from .kg_summarization import KG_SUMMARIZATION_PROMPT
+from .mm_kg_extraction import MMKG_EXTRACTION_PROMPT
graphgen/templates/{kg_extraction.py → kg/kg_extraction.py}
RENAMED
@@ -1,10 +1,9 @@
 # pylint: disable=C0301
-
 TEMPLATE_EN: str = """You are an NLP expert, skilled at analyzing text to extract named entities and their relationships.
 
 -Goal-
 Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Use {language} as output language.
+Use English as output language.
 
 -Steps-
 1. Identify all entities. For each identified entity, extract the following information:
@@ -23,7 +22,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
 Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
 
-4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
+4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
 
 5. When finished, output {completion_delimiter}
 
@@ -85,7 +84,7 @@ TEMPLATE_ZH: str = """你是一个NLP专家,擅长分析文本提取命名实体及其关系。
 
 -目标-
 给定一个实体类型列表和可能与列表相关的文本,从文本中识别所有这些类型的实体,以及这些实体之间所有的关系。
-使用{language}作为输出语言。
+使用中文作为输出语言。
 
 -步骤-
 1. 识别所有实体。对于每个识别的实体,提取以下信息:
@@ -189,12 +188,12 @@ Answer YES | NO if there are still entities and relationships that need to be added.
 IF_LOOP_ZH: str = """看起来可能仍然遗漏了一些实体和关系。如果仍有实体和关系需要添加,请回答YES | NO。"""
 
 KG_EXTRACTION_PROMPT: dict = {
-    "English": {
+    "en": {
         "TEMPLATE": TEMPLATE_EN,
         "CONTINUE": CONTINUE_EN,
         "IF_LOOP": IF_LOOP_EN,
     },
-    "Chinese": {
+    "zh": {
         "TEMPLATE": TEMPLATE_ZH,
         "CONTINUE": CONTINUE_ZH,
         "IF_LOOP": IF_LOOP_ZH,
@@ -205,6 +204,5 @@ KG_EXTRACTION_PROMPT: dict = {
         "completion_delimiter": "<|COMPLETE|>",
         "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, \
 science, technology, mission, gene",
-        "language": "English",
     },
 }
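The key rename from "English"/"Chinese" to "en"/"zh" means the prompt table can now be indexed directly with the codes returned by `detect_main_language` (updated further down); a sketch:

```python
from graphgen.templates import KG_EXTRACTION_PROMPT
from graphgen.utils.detect_lang import detect_main_language

text = "The Eiffel Tower stands in Paris."   # made-up input
lang = detect_main_language(text)            # -> "en"
template = KG_EXTRACTION_PROMPT[lang]["TEMPLATE"]
```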
graphgen/templates/{kg_summarization.py → kg/kg_summarization.py}
RENAMED
@@ -3,7 +3,7 @@ Given one entity or relationship, and a list of descriptions, all related to the same entity or relationship.
 Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
 If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
 Make sure it is written in third person, and include the entity names so we the have full context.
-Use {language} as output language.
+Use English as output language.
 
 #######
 -Data-
@@ -18,7 +18,7 @@ TEMPLATE_ZH = """你是一个NLP专家,负责根据以下提供的数据生成
 请将所有这些描述整合成一个综合描述。确保包含所有描述中收集的信息。
 如果提供的描述是矛盾的,请解决这些矛盾并提供一个连贯的总结。
 确保以第三人称写作,并包含实体名称,以便我们有完整的上下文。
-使用{language}作为输出语言。
+使用中文作为输出语言。
 
 #######
 -数据-
@@ -30,14 +30,9 @@ TEMPLATE_ZH = """你是一个NLP专家,负责根据以下提供的数据生成
 
 
 KG_SUMMARIZATION_PROMPT = {
-    "Chinese": {
-        "TEMPLATE": TEMPLATE_ZH
-    },
-    "English": {
-        "TEMPLATE": TEMPLATE_EN
-    },
+    "zh": {"TEMPLATE": TEMPLATE_ZH},
+    "en": {"TEMPLATE": TEMPLATE_EN},
     "FORMAT": {
-        "language": "English",
         "tuple_delimiter": "<|>",
         "record_delimiter": "##",
         "completion_delimiter": "<|COMPLETE|>",
graphgen/templates/kg/mm_kg_extraction.py
ADDED
@@ -0,0 +1,131 @@
+# pylint: disable=C0301
+TEMPLATE_EN: str = """You are an expert in multi-modal data analysis and knowledge graph construction. Your task is to extract named entities and relationships from a given multi-modal data chunk and its accompanying text.
+
+-Objective-
+Given a multi-modal data chunk (e.g., image, table, formula, etc. + accompanying text), construct a knowledge graph centered around the "central multi-modal entity":
+- The central entity must be the image/table/formula itself (e.g., image-c71ef797e99af81047fbc7509609c765).
+- Related entities and relationships must be extracted from the accompanying text.
+- Only retain edges directly connected to the central entity, forming a star-shaped graph.
+Use English as the output language.
+
+-Steps-
+1. Identify the unique central multi-modal entity and recognize all text entities directly related to the central entity from the accompanying text.
+For the central entity, extract the following information:
+- entity_name: Use the unique identifier of the data chunk (e.g., image-c71ef797e99af81047fbc7509609c765).
+- entity_type: Label according to the type of data chunk (image, table, formula, etc.).
+- entity_summary: A brief description of the content of the data chunk and its role in the accompanying text.
+For each entity recognized from the accompanying text, extract the following information:
+- entity_name: The name of the entity, capitalized
+- entity_type: One of the following types: [{entity_types}]
+- entity_summary: A comprehensive summary of the entity's attributes and activities
+Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+2. From the entities identified in Step 1, recognize all (source_entity, target_entity) pairs that are *obviously related* to each other.
+For each pair of related entities, extract the following information:
+- source_entity: The name of the source entity identified in Step 1
+- target_entity: The name of the target entity identified in Step 1
+- relationship_summary: Explain why you think the source entity and target entity are related to each other
+Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+3. Return the output list of all entities and relationships identified in Steps 1 and 2 in English. Use **{record_delimiter}** as the list separator.
+
+4. Upon completion, output {completion_delimiter}
+
+################
+-Example-
+################
+Multi-modal data chunk type: image
+Multi-modal data chunk unique identifier: image-c71ef797e99af81047fbc7509609c765
+Accompanying text: The Eiffel Tower is an iconic structure in Paris, France, designed by Gustave Eiffel and completed in 1889. It stands 324 meters tall and is one of the tallest structures in the world. The Eiffel Tower is located on the banks of the Seine River and attracts millions of visitors each year. It is not only an engineering marvel but also an important symbol of French culture.
+################
+Output:
+("entity"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"image"{tuple_delimiter}"This is an image showcasing the iconic structure in Paris, France, the Eiffel Tower, highlighting its full height of 324 meters along with the riverside scenery, symbolizing both engineering and cultural significance"){record_delimiter}
+("entity"{tuple_delimiter}"Eiffel Tower"{tuple_delimiter}"landmark"{tuple_delimiter}"The Eiffel Tower is an iconic structure in Paris, France, designed by Gustave Eiffel and completed in 1889, standing 324 meters tall, located on the banks of the Seine River, attracting millions of visitors each year"){record_delimiter}
+("entity"{tuple_delimiter}"Paris, France"{tuple_delimiter}"location"{tuple_delimiter}"Paris, France is the capital of France, known for its rich historical and cultural heritage and as the location of the Eiffel Tower"){record_delimiter}
+("entity"{tuple_delimiter}"Gustave Eiffel"{tuple_delimiter}"person"{tuple_delimiter}"Gustave Eiffel is a renowned French engineer who designed and built the Eiffel Tower"){record_delimiter}
+("entity"{tuple_delimiter}"Seine River"{tuple_delimiter}"location"{tuple_delimiter}"The Seine River is a major river flowing through Paris, France, with the Eiffel Tower located on its banks"){completion_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"Eiffel Tower"{tuple_delimiter}"The image showcases the iconic structure, the Eiffel Tower"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"Paris, France"{tuple_delimiter}"The image's background is Paris, France, highlighting the geographical location of the Eiffel Tower"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"Gustave Eiffel"{tuple_delimiter}"The Eiffel Tower in the image was designed by Gustave Eiffel"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"Seine River"{tuple_delimiter}"The image showcases the scenery of the Eiffel Tower located on the banks of the Seine River"){completion_delimiter}
+################
+
+-Real Data-
+Multi-modal data chunk type: {chunk_type}
+Multi-modal data chunk unique identifier: {chunk_id}
+Accompanying text: {chunk_text}
+################
+Output:
+"""
+
+TEMPLATE_ZH: str = """你是一个多模态数据分析和知识图谱构建专家。你的任务是从给定的多模态数据块及其伴随文本中抽取命名实体与关系。
+
+-目标-
+给定一个多模态数据块(例如图像、表格、公式等 + 伴随文本),构建以「中心多模态实体」为核心的知识图:
+- 中心实体必须是图像/表格/公式本身(如 image-c71ef797e99af81047fbc7509609c765)。
+- 相关实体和关系必须从伴随文本中抽取。
+- 只保留与中心实体直接相连的边,形成星型图。
+使用中文作为输出语言。
+
+-步骤-
+1. 确定唯一的中心多模态实体,从伴随文本中识别所有与中心实体直接相关的文本实体。
+对于中心实体,提取以下信息:
+- entity_name:使用数据块的唯一标识符(如 image-c71ef797e99af81047fbc7509609c765)。
+- entity_type:根据数据块类型(图像、表格、公式等)进行标注。
+- entity_summary:简要描述数据块的内容和其在伴随文本中的作用。
+对于从伴随文本中识别的每个实体,提取以下信息:
+- entity_name:实体的名称,首字母大写
+- entity_type:以下类型之一:[{entity_types}]
+- entity_summary:实体的属性与活动的全面总结
+将每个实体格式化为("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_summary>)
+
+2. 从步骤1中识别的实体中,识别所有(源实体,目标实体)对,这些实体彼此之间*明显相关*。
+对于每对相关的实体,提取以下信息:
+- source_entity:步骤1中识别的源实体名称
+- target_entity:步骤1中识别的目标实体名称
+- relationship_summary:解释为什么你认为源实体和目标实体彼此相关
+将每个关系格式化为("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_summary>)
+
+3. 以中文返回步骤1和2中识别出的所有实体和关系的输出列表。使用**{record_delimiter}**作为列表分隔符。
+
+4. 完成后,输出{completion_delimiter}
+
+################
+-示例-
+################
+多模态数据块类型:image
+多模态数据块唯一标识符:image-c71ef797e99af81047fbc7509609c765
+伴随文本:埃菲尔铁塔是法国巴黎的标志性结构,由古斯塔夫·埃菲尔设计并于1889年建成。它高324米,是世界上最高的建筑之一。埃菲尔铁塔位于塞纳河畔,吸引了数百万游客前来参观。它不仅是工程学的奇迹,也是法国文化的重要象征。
+################
+输出:
+("entity"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"image"{tuple_delimiter}"这是一张展示法国巴黎标志性建筑的图像,主体为埃菲尔铁塔,呈现其324米高度的全貌与河畔景观,具有工程与文化双重象征意义"){record_delimiter}
+("entity"{tuple_delimiter}"埃菲尔铁塔"{tuple_delimiter}"landmark"{tuple_delimiter}"埃菲尔铁塔是法国巴黎的标志性结构,由古斯塔夫·埃菲尔设计并于1889年建成,高324米,是世界上最高的建筑之一,位于塞纳河畔,吸引了数百万游客前来参观"){record_delimiter}
+("entity"{tuple_delimiter}"法国巴黎"{tuple_delimiter}"location"{tuple_delimiter}"法国巴黎是法国的首都,以其丰富的历史文化遗产和作为埃菲尔铁塔所在地而闻名"){record_delimiter}
+("entity"{tuple_delimiter}"古斯塔夫·埃菲尔"{tuple_delimiter}"person"{tuple_delimiter}"古斯塔夫·埃菲尔是法国著名的工程师,设计并建造了埃菲尔铁塔"){record_delimiter}
+("entity"{tuple_delimiter}"塞纳河"{tuple_delimiter}"location"{tuple_delimiter}"塞纳河是流经法国巴黎的重要河流,埃菲尔铁塔位于其畔"){completion_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"埃菲尔铁塔"{tuple_delimiter}"图像展示了埃菲尔铁塔这一标志性建筑"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"法国巴黎"{tuple_delimiter}"图像背景为法国巴黎,突显了埃菲尔铁塔的地理位置"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"古斯塔夫·埃菲尔"{tuple_delimiter}"图像中的埃菲尔铁塔是由古斯塔夫·埃菲尔设计的"){record_delimiter}
+("relationship"{tuple_delimiter}"image-c71ef797e99af81047fbc7509609c765"{tuple_delimiter}"塞纳河"{tuple_delimiter}"图像展示了埃菲尔铁塔位于塞纳河畔的景观"){completion_delimiter}
+################
+
+-真实数据-
+多模态数据块类型: {chunk_type}
+多模态数据块唯一标识符: {chunk_id}
+伴随文本: {chunk_text}
+################
+输出:
+"""
+
+
+MMKG_EXTRACTION_PROMPT: dict = {
+    "en": TEMPLATE_EN,
+    "zh": TEMPLATE_ZH,
+    "FORMAT": {
+        "tuple_delimiter": "<|>",
+        "record_delimiter": "##",
+        "completion_delimiter": "<|COMPLETE|>",
+        "entity_types": "concept, date, location, keyword, organization, person, event, work, nature, artificial, \
+science, technology, mission, gene",
+    },
+}
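A sketch of filling the new multimodal template; everything except the `chunk_*` values comes from the `FORMAT` dict at the bottom of the file:

```python
from graphgen.templates import MMKG_EXTRACTION_PROMPT

fmt = MMKG_EXTRACTION_PROMPT["FORMAT"]
prompt = MMKG_EXTRACTION_PROMPT["en"].format(
    chunk_type="image",
    chunk_id="image-c71ef797e99af81047fbc7509609c765",  # example id from the prompt itself
    chunk_text="The Eiffel Tower is an iconic structure in Paris, France.",
    **fmt,  # supplies entity_types and the three delimiters
)
```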
graphgen/utils/__init__.py
CHANGED
@@ -9,7 +9,7 @@ from .format import (
     split_string_by_multi_markers,
     write_json,
 )
-from .hash import compute_args_hash, compute_content_hash
+from .hash import compute_args_hash, compute_content_hash, compute_mm_hash
 from .help_nltk import NLTKHelper
 from .log import logger, parse_log, set_logger
 from .loop import create_event_loop
graphgen/utils/detect_lang.py
CHANGED
@@ -1,40 +1,41 @@
 def detect_main_language(text):
     """
-    识别文本的主要语言,'zh' 表示中文,'en' 表示英文
+    Detect the main language of the text, 'zh' for Chinese, 'en' for English
 
     :param text:
     :return:
     """
     assert isinstance(text, str)
+
     def is_chinese_char(char):
-        return '\u4e00' <= char <= '\u9fff'
+        return "\u4e00" <= char <= "\u9fff"
 
     def is_english_char(char):
         return char.isascii() and char.isalpha()
 
-    # 去除空白字符
-    text = ''.join(char for char in text if char.strip())
+    text = "".join(char for char in text if char.strip())
 
     chinese_count = sum(1 for char in text if is_chinese_char(char))
     english_count = sum(1 for char in text if is_english_char(char))
 
     total = chinese_count + english_count
     if total == 0:
-        return 'en'
+        return "en"
 
     chinese_ratio = chinese_count / total
 
     if chinese_ratio >= 0.5:
-        return 'zh'
-    return 'en'
+        return "zh"
+    return "en"
+
 
 def detect_if_chinese(text):
     """
-    判断文本是否包含中文
+    Detect if the text contains any Chinese characters
 
     :param text:
     :return:
     """
 
     assert isinstance(text, str)
-    return any('\u4e00' <= char <= '\u9fff' for char in text)
+    return any("\u4e00" <= char <= "\u9fff" for char in text)
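The behaviour is easy to spot-check; these assertions follow directly from the new code:

```python
from graphgen.utils.detect_lang import detect_if_chinese, detect_main_language

assert detect_main_language("Hello world") == "en"
assert detect_main_language("埃菲尔铁塔位于巴黎") == "zh"
assert detect_main_language("") == "en"   # no counted characters -> default "en"
assert detect_if_chinese("Hello 世界") is True
```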
graphgen/utils/hash.py
CHANGED
@@ -1,7 +1,23 @@
 from hashlib import md5
 
+
 def compute_args_hash(*args):
     return md5(str(args).encode()).hexdigest()
 
+
 def compute_content_hash(content, prefix: str = ""):
     return prefix + md5(content.encode()).hexdigest()
+
+
+def compute_mm_hash(item, prefix: str = ""):
+    if item.get("type") == "text" and item.get("text"):
+        content = item["text"].strip()
+    elif item.get("type") == "image" and item.get("img_path"):
+        content = f"image:{item['img_path']}"
+    elif item.get("type") == "table" and item.get("table_body"):
+        content = f"table:{item['table_body']}"
+    elif item.get("type") == "equation" and item.get("text"):
+        content = f"equation:{item['text']}"
+    else:
+        content = str(item)
+    return prefix + md5(content.encode()).hexdigest()
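`compute_mm_hash` reduces each modality to a canonical string before hashing; a sketch of the image case using only the stdlib:

```python
from hashlib import md5

item = {"type": "image", "img_path": "figs/eiffel.png"}  # hypothetical chunk record

# Per the diff, compute_mm_hash(item, prefix="image-") is equivalent to:
expected = "image-" + md5(f"image:{item['img_path']}".encode()).hexdigest()
```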
graphgen/utils/log.py
CHANGED
@@ -8,7 +8,8 @@ logger = logging.getLogger("graphgen")
 
 def set_logger(
     log_file: str,
-    log_level: int = logging.INFO,
+    file_level: int = logging.DEBUG,
+    console_level: int = logging.INFO,
     *,
     if_stream: bool = True,
     max_bytes: int = 50 * 1024 * 1024,  # 50 MB
@@ -22,14 +23,18 @@ def set_logger(
     if force:
         logger.handlers.clear()
 
-    logger.setLevel(log_level)
+    logger.setLevel(
+        min(file_level, console_level)
+    )  # Set to the lowest level to capture all logs
     logger.propagate = False
 
     if logger.handlers:
         logger.handlers.clear()
 
     if if_stream:
-        console = RichHandler(level=log_level, show_path=False, rich_tracebacks=True)
+        console = RichHandler(
+            level=console_level, show_path=False, rich_tracebacks=True
+        )
         console.setFormatter(logging.Formatter("%(message)s"))
         logger.addHandler(console)
 
@@ -39,7 +44,7 @@ def set_logger(
         backupCount=backup_count,
         encoding="utf-8",
     )
-    file_handler.setLevel(log_level)
+    file_handler.setLevel(file_level)
     file_handler.setFormatter(
         logging.Formatter(
             "[%(asctime)s] %(levelname)s [%(name)s:%(filename)s:%(lineno)d] %(message)s",
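With the split levels, a typical call (a sketch; the log path is made up) keeps the console readable while the rotating file captures everything:

```python
import logging

from graphgen.utils import set_logger

# DEBUG and below go to the rotating file; the Rich console stays at INFO.
set_logger("cache/graphgen.log", file_level=logging.DEBUG, console_level=logging.INFO)
```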