github-actions[bot] committed
Commit: 799ac7c · Parent: 37f0321

Auto-sync from demo at Wed Oct 15 06:28:02 UTC 2025
This view is limited to 50 files because it contains too many changes.
- app.py +1 -1
- graphgen/bases/__init__.py +2 -0
- graphgen/bases/base_generator.py +84 -0
- graphgen/bases/base_partitioner.py +76 -0
- graphgen/bases/base_storage.py +2 -2
- graphgen/bases/datatypes.py +8 -0
- graphgen/configs/aggregated_config.yaml +4 -8
- graphgen/configs/atomic_config.yaml +2 -9
- graphgen/configs/cot_config.yaml +3 -3
- graphgen/configs/multi_hop_config.yaml +4 -8
- graphgen/graphgen.py +14 -52
- graphgen/models/__init__.py +15 -8
- graphgen/models/community/__init__.py +0 -0
- graphgen/models/community/community_detector.py +0 -95
- graphgen/models/evaluate/__init__.py +0 -0
- graphgen/models/evaluator/__init__.py +4 -0
- graphgen/models/{evaluate → evaluator}/base_evaluator.py +0 -0
- graphgen/models/{evaluate → evaluator}/length_evaluator.py +1 -1
- graphgen/models/{evaluate → evaluator}/mtld_evaluator.py +1 -1
- graphgen/models/{evaluate → evaluator}/reward_evaluator.py +0 -0
- graphgen/models/{evaluate → evaluator}/uni_evaluator.py +0 -0
- graphgen/models/generator/__init__.py +4 -0
- graphgen/models/generator/aggregated_generator.py +127 -0
- graphgen/models/generator/atomic_generator.py +52 -0
- graphgen/models/generator/cot_generator.py +122 -0
- graphgen/models/generator/multi_hop_generator.py +55 -0
- graphgen/models/kg_builder/__init__.py +1 -0
- graphgen/models/llm/limitter.py +27 -29
- graphgen/models/llm/openai_client.py +4 -2
- graphgen/models/partitioner/__init__.py +4 -0
- graphgen/models/partitioner/bfs_partitioner.py +83 -0
- graphgen/models/partitioner/dfs_partitioner.py +80 -0
- graphgen/models/partitioner/ece_partitioner.py +163 -0
- graphgen/models/partitioner/leiden_partitioner.py +120 -0
- graphgen/models/storage/__init__.py +2 -0
- graphgen/models/storage/networkx_storage.py +4 -4
- graphgen/operators/__init__.py +4 -9
- graphgen/operators/build_kg/__init__.py +1 -0
- graphgen/operators/build_kg/split_kg.py +0 -382
- graphgen/operators/generate/__init__.py +1 -0
- graphgen/operators/generate/generate_cot.py +0 -117
- graphgen/operators/generate/generate_qas.py +58 -0
- graphgen/operators/partition/__init__.py +1 -0
- graphgen/operators/partition/partition_kg.py +48 -0
- graphgen/operators/partition/pre_tokenize.py +47 -0
- graphgen/operators/search/__init__.py +1 -0
- graphgen/operators/traverse_graph.py +0 -540
- graphgen/templates/__init__.py +6 -3
- graphgen/templates/community/__init__.py +0 -2
- graphgen/templates/community/cot_generation.py +0 -87
app.py
CHANGED
@@ -468,7 +468,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 label="TPM",
                 minimum=5000,
                 maximum=5000000,
-                value=
+                value=100000,
                 step=1000,
                 interactive=True,
                 visible=True,
graphgen/bases/__init__.py
CHANGED
@@ -1,5 +1,7 @@
+from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_client import BaseLLMClient
+from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_splitter import BaseSplitter
 from .base_storage import (
graphgen/bases/base_generator.py
ADDED
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+
+
+@dataclass
+class BaseGenerator(ABC):
+    """
+    Generate QAs based on given prompts.
+    """
+
+    llm_client: BaseLLMClient
+
+    @staticmethod
+    @abstractmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """Build prompt for LLM based on the given batch"""
+
+    @staticmethod
+    @abstractmethod
+    def parse_response(response: str) -> Any:
+        """Parse the LLM response and return the generated QAs"""
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
+        result.update(qa_pairs)
+        return result
+
+    @staticmethod
+    def format_generation_results(
+        results: list[dict], output_data_format: str
+    ) -> list[dict[str, Any]]:
+        if output_data_format == "Alpaca":
+            results = [
+                {
+                    "instruction": v["question"],
+                    "input": "",
+                    "output": v["answer"],
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "Sharegpt":
+            results = [
+                {
+                    "conversations": [
+                        {"from": "human", "value": v["question"]},
+                        {"from": "gpt", "value": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "ChatML":
+            results = [
+                {
+                    "messages": [
+                        {"role": "user", "content": v["question"]},
+                        {"role": "assistant", "content": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        else:
+            raise ValueError(f"Unknown output data format: {output_data_format}")
+        return results
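The contract above is small: a subclass supplies build_prompt and parse_response and inherits the single-call generate loop. A minimal hypothetical sketch follows; EchoGenerator and StubLLM are invented names for illustration, and only BaseGenerator and the generate_answer call come from the code above.

import asyncio
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator


@dataclass
class EchoGenerator(BaseGenerator):
    # Hypothetical subclass: one QA pair per community batch.
    @staticmethod
    def build_prompt(batch) -> str:
        nodes, _edges = batch
        facts = "\n".join(f"{name}: {data.get('description', '')}" for name, data in nodes)
        return "Ask one question about:\n" + facts

    @staticmethod
    def parse_response(response: str) -> Any:
        # generate() merges this dict into its result.
        return {"demo": {"question": response, "answer": response}}


class StubLLM:
    # Stand-in for a real BaseLLMClient (duck-typed).
    async def generate_answer(self, prompt: str) -> str:
        return "What is described here?"


batch = ([("Paris", {"description": "capital of France"})], [])
print(asyncio.run(EchoGenerator(llm_client=StubLLM()).generate(batch)))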
graphgen/bases/base_partitioner.py
ADDED
@@ -0,0 +1,76 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class BasePartitioner(ABC):
+    @abstractmethod
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        **kwargs: Any,
+    ) -> List[Community]:
+        """
+        Graph -> Communities
+        :param g: Graph storage instance
+        :param kwargs: Additional parameters for partitioning
+        :return: List of communities
+        """
+
+    @staticmethod
+    async def community2batch(
+        communities: List[Community], g: BaseGraphStorage
+    ) -> list[
+        tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ]
+    ]:
+        """
+        Convert communities to batches of nodes and edges.
+        :param communities
+        :param g: Graph storage instance
+        :return: List of batches, each batch is a tuple of (nodes, edges)
+        """
+        batches = []
+        for comm in communities:
+            nodes = comm.nodes
+            edges = comm.edges
+            nodes_data = []
+            for node in nodes:
+                node_data = await g.get_node(node)
+                if node_data:
+                    nodes_data.append((node, node_data))
+            edges_data = []
+            for u, v in edges:
+                edge_data = await g.get_edge(u, v)
+                if edge_data:
+                    edges_data.append((u, v, edge_data))
+                else:
+                    edge_data = await g.get_edge(v, u)
+                    if edge_data:
+                        edges_data.append((v, u, edge_data))
+            batches.append((nodes_data, edges_data))
+        return batches
+
+    @staticmethod
+    def _build_adjacency_list(
+        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
+    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
+        """
+        Build adjacency list and edge set from nodes and edges.
+        :param nodes
+        :param edges
+        :return: adjacency list, edge set
+        """
+        adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
+        edge_set: set[tuple[str, str]] = set()
+        for e in edges:
+            adj[e[0]].append(e[1])
+            adj[e[1]].append(e[0])
+            edge_set.add((e[0], e[1]))
+            edge_set.add((e[1], e[0]))
+        return adj, edge_set
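community2batch and _build_adjacency_list together define how a Community becomes LLM input: node and edge IDs are resolved to their stored data, and undirected adjacency is recorded in both directions. A self-contained illustration of the adjacency bookkeeping (the logic is restated inline, so no graphgen import is needed):

nodes = [("a", {}), ("b", {}), ("c", {})]
edges = [("a", "b", {}), ("b", "c", {})]

adj = {n: [] for n, _ in nodes}
edge_set = set()
for u, v, _ in edges:
    adj[u].append(v)
    adj[v].append(u)      # undirected: record both directions
    edge_set.add((u, v))
    edge_set.add((v, u))  # membership tests work from either orientation

assert adj == {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
assert ("c", "b") in edge_set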
graphgen/bases/base_storage.py
CHANGED
@@ -78,7 +78,7 @@ class BaseGraphStorage(StorageNameSpace):
     async def update_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError

-    async def get_all_nodes(self) -> Union[list[dict], None]:
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
         raise NotImplementedError

     async def get_edge(
@@ -91,7 +91,7 @@ class BaseGraphStorage(StorageNameSpace):
     ):
         raise NotImplementedError

-    async def get_all_edges(self) -> Union[list[dict], None]:
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         raise NotImplementedError

     async def get_node_edges(
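The two signature changes align storage with the partitioners added below: get_all_nodes now yields (node_id, data) pairs and get_all_edges yields (u, v, data) triples. A sketch of how a NetworkX-backed implementation could satisfy the new contract (assumed for illustration; the networkx_storage.py hunk itself is not shown in this capture):

import networkx as nx

class SketchStorage:
    def __init__(self):
        self.graph = nx.Graph()

    async def get_all_nodes(self):
        # (node_id, attribute dict) tuples, matching the new base signature
        return list(self.graph.nodes(data=True))

    async def get_all_edges(self):
        # (u, v, attribute dict) tuples
        return list(self.graph.edges(data=True))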
graphgen/bases/datatypes.py
CHANGED
@@ -30,3 +30,11 @@ class Token:
     @property
     def logprob(self) -> float:
         return math.log(self.prob)
+
+
+@dataclass
+class Community:
+    id: Union[int, str]
+    nodes: List[str] = field(default_factory=list)
+    edges: List[tuple] = field(default_factory=list)
+    metadata: dict = field(default_factory=dict)
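For orientation, constructing the new dataclass directly looks like this (values are made up):

from graphgen.bases.datatypes import Community

comm = Community(
    id=0,
    nodes=["Paris", "France"],
    edges=[("Paris", "France")],
    metadata={"method": "bfs"},
)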
graphgen/configs/aggregated_config.yaml
CHANGED
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
   method_params:
-
-
-
-
-    max_depth: 5 # maximum depth for graph traversal
-    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 20 # max nodes and edges per community
+    min_units_per_community: 5 # min nodes and edges per community
+    max_tokens_per_community: 10240 # max tokens per community
+    unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
 generate:
   mode: aggregated # atomic, aggregated, multi_hop, cot
   data_format: ChatML # Alpaca, Sharegpt, ChatML
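All four configs share this shape: partition.method names a partitioner and partition.method_params carries its keyword arguments. A hedged sketch of the dispatch this implies (partition_kg's real body is not shown in this truncated view, so everything here beyond the partitioner classes and the YAML keys is an assumption):

import yaml

from graphgen.models import (
    BFSPartitioner,
    DFSPartitioner,
    ECEPartitioner,
    LeidenPartitioner,
)

PARTITIONERS = {
    "bfs": BFSPartitioner,
    "dfs": DFSPartitioner,
    "ece": ECEPartitioner,
    "leiden": LeidenPartitioner,
}

async def partition_with_config(graph_storage, config_path: str):
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)["partition"]
    partitioner = PARTITIONERS[cfg["method"]]()
    # method_params (e.g. max_units_per_community) are forwarded as kwargs
    return await partitioner.partition(graph_storage, **cfg.get("method_params", {}))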
graphgen/configs/atomic_config.yaml
CHANGED
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
-  method:
+  method: dfs # partition method, support: dfs, bfs, ece, leiden
   method_params:
-
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 3 # maximum depth for graph traversal
-    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 1 # atomic partition, one node or edge per community
 generate:
   mode: atomic # atomic, aggregated, multi_hop, cot
   data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml
CHANGED
@@ -9,11 +9,11 @@ search: # web search configuration
 quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
 partition: # graph partition configuration
-  method: leiden # leiden is a
+  method: leiden # leiden is a partitioner detection algorithm
   method_params:
     max_size: 20 # Maximum size of communities
-    use_lcc: false
-    random_seed: 42
+    use_lcc: false # whether to use the largest connected component
+    random_seed: 42 # random seed for partitioning
 generate:
   mode: cot # atomic, aggregated, multi_hop, cot
   data_format: Sharegpt # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml
CHANGED
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
   method_params:
-
-
-
-
-    max_depth: 1 # maximum depth for graph traversal
-    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
+    min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
+    max_tokens_per_community: 10240 # max tokens per community
+    unit_sampling: random # edge sampling strategy, support: random, max_loss, min_loss
 generate:
   mode: multi_hop # strategy for generating multi-hop QA pairs
   data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/graphgen.py
CHANGED
@@ -18,21 +18,14 @@ from graphgen.models import (
 from graphgen.operators import (
     build_kg,
     chunk_documents,
-
+    generate_qas,
     judge_statement,
+    partition_kg,
     quiz,
     read_files,
     search_all,
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
-)
-from graphgen.utils import (
-    async_to_sync_method,
-    compute_content_hash,
-    format_generation_results,
-    logger,
 )
+from graphgen.utils import async_to_sync_method, compute_content_hash, logger

 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

@@ -238,51 +231,20 @@ class GraphGen:
     @async_to_sync_method
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
-
-
-
-            results = await traverse_graph_for_atomic(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "multi_hop":
-            results = await traverse_graph_for_multi_hop(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "aggregated":
-            results = await traverse_graph_for_aggregated(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "cot":
-            results = await generate_cot(
-                self.graph_storage,
-                self.synthesizer_llm_client,
-                method_params=partition_config["method_params"],
-            )
-        else:
-            raise ValueError(f"Unknown generation mode: {mode}")
-        # Step 2: generate QA pairs
-        # TODO
-
-        # Step
-        results =
-
-        )
+        batches = await partition_kg(
+            self.graph_storage, self.tokenizer_instance, partition_config
+        )
+
+        # Step 2: generate QA pairs
+        results = await generate_qas(
+            self.synthesizer_llm_client, batches, generate_config
+        )
+
+        if not results:
+            logger.warning("No QA pairs generated")
+            return
+
+        # Step 3: store the generated QA pairs
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
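The rewritten generate() is now a fixed three-step pipeline regardless of mode: partition, generate, store. A hedged caller sketch (the GraphGen constructor arguments are not visible in this diff, so they are elided; the config keys match the YAML files above):

import yaml

from graphgen.graphgen import GraphGen

with open("graphgen/configs/atomic_config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

gg = GraphGen(...)  # constructor arguments elided; not shown in this diff
gg.generate(
    partition_config=config["partition"],
    generate_config=config["generate"],
)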
graphgen/models/__init__.py
CHANGED
@@ -1,17 +1,24 @@
-from .
-from .
-
-
-
-
+from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
+from .generator import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
+from .kg_builder import LightRAGKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
+from .partitioner import (
+    BFSPartitioner,
+    DFSPartitioner,
+    ECEPartitioner,
+    LeidenPartitioner,
+)
 from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
 from .search.db.uniprot_search import UniProtSearch
 from .search.kg.wiki_search import WikiSearch
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage
-from .storage.networkx_storage import NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
 from .tokenizer import Tokenizer
graphgen/models/community/__init__.py
DELETED
File without changes
graphgen/models/community/community_detector.py
DELETED
@@ -1,95 +0,0 @@
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-from graphgen.models.storage.networkx_storage import NetworkXStorage
-
-
-@dataclass
-class CommunityDetector:
-    """Class for community detection algorithms."""
-
-    graph_storage: NetworkXStorage = None
-    method: str = "leiden"
-    method_params: Dict[str, Any] = None
-
-    async def detect_communities(self) -> Dict[str, int]:
-        if self.method == "leiden":
-            return await self._leiden_communities(**self.method_params or {})
-        raise ValueError(f"Unknown community detection method: {self.method}")
-
-    async def get_graph(self):
-        return await self.graph_storage.get_graph()
-
-    async def _leiden_communities(
-        self, max_size: int = None, **kwargs
-    ) -> Dict[str, int]:
-        """
-        Detect communities using the Leiden algorithm.
-        If max_size is given, any community larger than max_size will be split
-        into smaller sub-communities each having at most max_size nodes.
-        """
-        import igraph as ig
-        import networkx as nx
-        from leidenalg import ModularityVertexPartition, find_partition
-
-        graph = await self.get_graph()
-        graph.remove_nodes_from(list(nx.isolates(graph)))
-
-        ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
-
-        random_seed = kwargs.get("random_seed", 42)
-        use_lcc = kwargs.get("use_lcc", False)
-
-        communities: Dict[str, int] = {}
-        if use_lcc:
-            lcc = ig_graph.components().giant()
-            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
-            for part, cluster in enumerate(partition):
-                for v in cluster:
-                    communities[lcc.vs[v]["name"]] = part
-        else:
-            offset = 0
-            for component in ig_graph.components():
-                subgraph = ig_graph.induced_subgraph(component)
-                partition = find_partition(
-                    subgraph, ModularityVertexPartition, seed=random_seed
-                )
-                for part, cluster in enumerate(partition):
-                    for v in cluster:
-                        original_node = subgraph.vs[v]["name"]
-                        communities[original_node] = part + offset
-                offset += len(partition)
-
-        # split large communities if max_size is specified
-        if max_size is None or max_size <= 0:
-            return communities
-
-        return await self._split_communities(communities, max_size)
-
-    @staticmethod
-    async def _split_communities(
-        communities: Dict[str, int], max_size: int
-    ) -> Dict[str, int]:
-        """
-        Split communities larger than max_size into smaller sub-communities.
-        """
-        cid2nodes: Dict[int, List[str]] = defaultdict(list)
-        for node, cid in communities.items():
-            cid2nodes[cid].append(node)
-
-        new_communities: Dict[str, int] = {}
-        new_cid = 0
-        for cid, nodes in cid2nodes.items():
-            if len(nodes) <= max_size:
-                for n in nodes:
-                    new_communities[n] = new_cid
-                new_cid += 1
-            else:
-                for start in range(0, len(nodes), max_size):
-                    sub = nodes[start : start + max_size]
-                    for n in sub:
-                        new_communities[n] = new_cid
-                    new_cid += 1
-
-        return new_communities
graphgen/models/evaluate/__init__.py
DELETED
File without changes
graphgen/models/evaluator/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .length_evaluator import LengthEvaluator
+from .mtld_evaluator import MTLDEvaluator
+from .reward_evaluator import RewardEvaluator
+from .uni_evaluator import UniEvaluator
graphgen/models/{evaluate → evaluator}/base_evaluator.py
RENAMED
File without changes
graphgen/models/{evaluate → evaluator}/length_evaluator.py
RENAMED
@@ -1,7 +1,7 @@
 from dataclasses import dataclass

 from graphgen.bases.datatypes import QAPair
-from graphgen.models.
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop
graphgen/models/{evaluate → evaluator}/mtld_evaluator.py
RENAMED
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 from typing import Set

 from graphgen.bases.datatypes import QAPair
-from graphgen.models.
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language

 nltk_helper = NLTKHelper()
graphgen/models/{evaluate → evaluator}/reward_evaluator.py
RENAMED
File without changes

graphgen/models/{evaluate → evaluator}/uni_evaluator.py
RENAMED
File without changes
graphgen/models/generator/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .aggregated_generator import AggregatedGenerator
+from .atomic_generator import AtomicGenerator
+from .cot_generator import CoTGenerator
+from .multi_hop_generator import MultiHopGenerator
graphgen/models/generator/aggregated_generator.py
ADDED
@@ -0,0 +1,127 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import AGGREGATED_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class AggregatedGenerator(BaseGenerator):
+    """
+    Aggregated Generator follows a TWO-STEP process:
+    1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
+       The rephrased text is considered as answer to be used in the next step.
+    2. question generation: Generate relevant questions based on the rephrased text.
+    """
+
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """
+        Build prompts for REPHRASE.
+        :param batch
+        :return:
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relations_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relations_str)
+
+        # TODO: configure add_context
+        # if add_context:
+        #     original_ids = [
+        #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
+        #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
+        #     original_ids = list(set(original_ids))
+        #     original_text = await text_chunks_storage.get_by_ids(original_ids)
+        #     original_text = "\n".join(
+        #         [
+        #             f"{index + 1}. {text['content']}"
+        #             for index, text in enumerate(original_text)
+        #         ]
+        #     )
+        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
+            language=language, entities=entities_str, relationships=relations_str
+        )
+        return prompt
+
+    @staticmethod
+    def parse_rephrased_text(response: str) -> str:
+        """
+        Parse the rephrased text from the response.
+        :param response:
+        :return: rephrased text
+        """
+        if "Rephrased Text:" in response:
+            rephrased_text = response.split("Rephrased Text:")[1].strip()
+        elif "重述文本:" in response:
+            rephrased_text = response.split("重述文本:")[1].strip()
+        else:
+            rephrased_text = response.strip()
+        return rephrased_text.strip('"')
+
+    @staticmethod
+    def _build_prompt_for_question_generation(answer: str) -> str:
+        """
+        Build prompts for QUESTION GENERATION.
+        :param answer:
+        :return:
+        """
+        language = detect_main_language(answer)
+        prompt = AGGREGATED_GENERATION_PROMPT[language]["QUESTION_GENERATION"].format(
+            answer=answer
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if response.startswith("Question:"):
+            question = response[len("Question:") :].strip()
+        elif response.startswith("问题:"):
+            question = response[len("问题:") :].strip()
+        else:
+            question = response.strip()
+        return {
+            "question": question,
+        }
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        rephrasing_prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(rephrasing_prompt)
+        context = self.parse_rephrased_text(response)
+        question_generation_prompt = self._build_prompt_for_question_generation(context)
+        response = await self.llm_client.generate_answer(question_generation_prompt)
+        question = self.parse_response(response)["question"]
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", context)
+        qa_pairs = {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": context,
+            }
+        }
+        result.update(qa_pairs)
+        return result
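Because the two LLM calls happen in sequence, the aggregated flow is easy to exercise with a call-counting stub client. A hedged sketch (the stub and the sample batch are invented for illustration; only AggregatedGenerator and generate come from the code above):

import asyncio

from graphgen.models import AggregatedGenerator

class StubLLM:
    def __init__(self):
        self.calls = 0

    async def generate_answer(self, prompt: str) -> str:
        self.calls += 1
        if self.calls == 1:  # first call: answer rephrasing
            return "Rephrased Text: Paris is the capital of France."
        return "Question: What is the capital of France?"  # second call

batch = (
    [("Paris", {"description": "capital city of France"})],
    [("Paris", "France", {"description": "Paris is located in France"})],
)
qa = asyncio.run(AggregatedGenerator(llm_client=StubLLM()).generate(batch))
# one entry: {"question": "What is the capital of France?", "answer": "Paris is the capital of France."}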
graphgen/models/generator/atomic_generator.py
ADDED
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import ATOMIC_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class AtomicGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        context = ""
+        for node in nodes:
+            context += f"- {node[0]}: {node[1]['description']}\n"
+        for edge in edges:
+            context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n"
+        language = detect_main_language(context)
+
+        prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context)
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        """
+        AtomicGenerator normally generates one QA pair per response.
+        So we just need to parse one QA pair from the response.
+        :param response:
+        :return:
+        """
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return {}
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
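The parser is a plain string split, so it can be checked in isolation (sample response text invented for illustration):

from graphgen.models import AtomicGenerator

resp = 'Question: "What is GraphGen?" Answer: "A KG-based synthetic-data framework."'
qa = AtomicGenerator.parse_response(resp)
# {<content hash>: {"question": "What is GraphGen?", "answer": "A KG-based synthetic-data framework."}}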
graphgen/models/generator/cot_generator.py
ADDED
@@ -0,0 +1,122 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import COT_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class CoTGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """
+        Build prompts for COT Template Design.
+        :param batch:
+        :return:
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = COT_GENERATION_PROMPT[language]["COT_TEMPLATE_DESIGN"].format(
+            entities=entities_str, relationships=relationships_str
+        )
+        return prompt
+
+    @staticmethod
+    def build_prompt_for_cot_generation(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]],
+        question: str,
+        reasoning_path: str,
+    ) -> str:
+        """
+        Build prompts for COT Generation.
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = COT_GENERATION_PROMPT[language]["COT_GENERATION"].format(
+            entities=entities_str,
+            relationships=relationships_str,
+            question=question,
+            reasoning_template=reasoning_path,
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if "Question:" in response and "Reasoning-Path Design:" in response:
+            question = (
+                response.split("Question:")[1]
+                .split("Reasoning-Path Design:")[0]
+                .strip()
+            )
+            reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
+        elif "问题:" in response and "推理路径设计:" in response:
+            question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
+            reasoning_path = response.split("推理路径设计:")[1].strip()
+        else:
+            logger.warning("Failed to parse CoT template: %s", response)
+            return {}
+
+        question = question.strip('"')
+        reasoning_path = reasoning_path.strip('"')
+        logger.info("CoT Question: %s", question)
+        logger.info("CoT Reasoning Path: %s", reasoning_path)
+        return {
+            "question": question,
+            "reasoning_path": reasoning_path,
+        }
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        response = self.parse_response(response)
+        question, reasoning_path = response["question"], response["reasoning_path"]
+        prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
+        cot_answer = await self.llm_client.generate_answer(prompt)
+        logger.info("CoT Answer: %s", cot_answer)
+        qa_pairs = {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": cot_answer,
+                "reasoning_path": reasoning_path,
+            }
+        }
+        result.update(qa_pairs)
+        return result
graphgen/models/generator/multi_hop_generator.py
ADDED
@@ -0,0 +1,55 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class MultiHopGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+            entities=entities_str, relationships=relationships_str
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return {}
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
graphgen/models/kg_builder/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .light_rag_kg_builder import LightRAGKGBuilder
graphgen/models/llm/limitter.py
CHANGED
@@ -1,17 +1,17 @@
+import asyncio
 import time
 from datetime import datetime, timedelta
-import asyncio

 from graphgen.utils import logger


 class RPM:
-
     def __init__(self, rpm: int = 1000):
         self.rpm = rpm
-        self.record = {
+        self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}

-
+    @staticmethod
+    def get_minute_slot():
         current_time = time.time()
         dt_object = datetime.fromtimestamp(current_time)
         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
@@ -22,37 +22,35 @@ class RPM:
         dt_object = datetime.fromtimestamp(current)
         minute_slot = self.get_minute_slot()

-        if self.record[
+        if self.record["rpm_slot"] == minute_slot:
             # check RPM exceed
-            if self.record[
+            if self.record["counter"] >= self.rpm:
                 # wait until next minute
-                next_minute = dt_object.replace(
-
+                next_minute = dt_object.replace(second=0, microsecond=0) + timedelta(
+                    minutes=1
+                )
                 _next = next_minute.timestamp()
                 sleep_time = abs(_next - current)
                 if not silent:
-                    logger.info(
+                    logger.info("RPM sleep %s", sleep_time)
                 await asyncio.sleep(sleep_time)

-                self.record = {
-                    'rpm_slot': self.get_minute_slot(),
-                    'counter': 0
-                }
+                self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}
         else:
-            self.record = {
-        self.record[
+            self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}
+        self.record["counter"] += 1

         if not silent:
             logger.debug(self.record)


 class TPM:
-
     def __init__(self, tpm: int = 20000):
         self.tpm = tpm
-        self.record = {
+        self.record = {"tpm_slot": self.get_minute_slot(), "counter": 0}

-
+    @staticmethod
+    def get_minute_slot():
         current_time = time.time()
         dt_object = datetime.fromtimestamp(current_time)
         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
@@ -64,25 +62,25 @@ class TPM:
         minute_slot = self.get_minute_slot()

         # get next slot, skip
-        if self.record[
-            self.record = {
+        if self.record["tpm_slot"] != minute_slot:
+            self.record = {"tpm_slot": minute_slot, "counter": token_count}
             return

         # check RPM exceed
-        self.record[
-
+        old_counter = self.record["counter"]
+        self.record["counter"] += token_count
+        if self.record["counter"] > self.tpm:
+            logger.info("Current TPM: %s, limit: %s", old_counter, self.tpm)
             # wait until next minute
-            next_minute = dt_object.replace(
-
+            next_minute = dt_object.replace(second=0, microsecond=0) + timedelta(
+                minutes=1
+            )
             _next = next_minute.timestamp()
             sleep_time = abs(_next - current)
-            logger.
+            logger.warning("TPM limit exceeded, wait %s seconds", sleep_time)
             await asyncio.sleep(sleep_time)

-            self.record = {
-                'tpm_slot': self.get_minute_slot(),
-                'counter': token_count
-            }
+            self.record = {"tpm_slot": self.get_minute_slot(), "counter": token_count}

         if not silent:
             logger.debug(self.record)
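A hedged usage sketch of the limiters. Note the wait() signatures are only inferred from the bodies above (RPM.wait appears to take silent, TPM.wait a token_count and silent), so they are assumptions, not confirmed API:

import asyncio

from graphgen.models.llm.limitter import RPM, TPM

async def limited_call(make_request, prompt_tokens: int, rpm: RPM, tpm: TPM):
    await rpm.wait(silent=True)                 # blocks until the minute's request budget allows
    await tpm.wait(prompt_tokens, silent=True)  # blocks if the token budget is exhausted
    return await make_request()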
graphgen/models/llm/openai_client.py
CHANGED
@@ -39,6 +39,8 @@ class OpenAIClient(BaseLLMClient):
         seed: Optional[int] = None,
         topk_per_token: int = 5,  # number of topk tokens to generate for each token
         request_limit: bool = False,
+        rpm: Optional[RPM] = None,
+        tpm: Optional[TPM] = None,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
@@ -51,8 +53,8 @@ class OpenAIClient(BaseLLMClient):

         self.token_usage: list = []
         self.request_limit = request_limit
-        self.rpm = RPM(
-        self.tpm = TPM(
+        self.rpm = rpm or RPM()
+        self.tpm = tpm or TPM()

         self.__post_init__()
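The new rpm/tpm parameters exist so several clients can share one per-minute budget instead of each constructing a private limiter. A brief sketch (any model or API-key arguments the constructor also needs are omitted here, as they are not visible in this diff):

from graphgen.models import OpenAIClient
from graphgen.models.llm.limitter import RPM, TPM

shared_rpm, shared_tpm = RPM(rpm=1000), TPM(tpm=50000)

# both clients now draw from the same rate budget
synthesizer = OpenAIClient(rpm=shared_rpm, tpm=shared_tpm)
trainee = OpenAIClient(rpm=shared_rpm, tpm=shared_tpm)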
graphgen/models/partitioner/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .bfs_partitioner import BFSPartitioner
+from .dfs_partitioner import DFSPartitioner
+from .ece_partitioner import ECEPartitioner
+from .leiden_partitioner import LeidenPartitioner
graphgen/models/partitioner/bfs_partitioner.py
ADDED
@@ -0,0 +1,83 @@
+import random
+from collections import deque
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class BFSPartitioner(BasePartitioner):
+    """
+    BFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Expand the community using BFS until the max unit size is reached.
+    (A unit is a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [(NODE_UNIT, n[0]) for n in nodes] + [
+            (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == NODE_UNIT and seed in used_n) or (
+                kind == EDGE_UNIT and seed in used_e
+            ):
+                continue
+
+            comm_n: List[str] = []
+            comm_e: List[tuple[str, str]] = []
+            queue: deque[tuple[str, Any]] = deque([(kind, seed)])
+            cnt = 0
+
+            while queue and cnt < max_units_per_community:
+                k, it = queue.popleft()
+                if k == NODE_UNIT:
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            queue.append((EDGE_UNIT, e_key))
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+
+                    u, v = it
+                    comm_e.append((u, v))
+                    cnt += 1
+                    # push nodes that are not visited
+                    for n in it:
+                        if n not in used_n:
+                            queue.append((NODE_UNIT, n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
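One detail shared by the BFS and DFS partitioners: undirected edges are keyed as frozenset((u, v)), so the same edge reached from either endpoint dedupes to a single unit. A self-contained illustration:

seen = set()
for u, v in [("a", "b"), ("b", "a")]:
    key = frozenset((u, v))
    if key in seen:
        print("duplicate edge skipped:", (u, v))
        continue
    seen.add(key)
# prints: duplicate edge skipped: ('b', 'a')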
graphgen/models/partitioner/dfs_partitioner.py
ADDED
@@ -0,0 +1,80 @@
+import random
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class DFSPartitioner(BasePartitioner):
+    """
+    DFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Random walk using DFS until the community reaches the max unit size.
+    (In GraphGen, a unit is defined as a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [(NODE_UNIT, n[0]) for n in nodes] + [
+            (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == NODE_UNIT and seed in used_n) or (
+                kind == EDGE_UNIT and seed in used_e
+            ):
+                continue
+
+            comm_n, comm_e = [], []
+            stack = [(kind, seed)]
+            cnt = 0
+
+            while stack and cnt < max_units_per_community:
+                k, it = stack.pop()
+                if k == NODE_UNIT:
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            stack.append((EDGE_UNIT, e_key))
+                            break
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+                    comm_e.append(tuple(it))
+                    cnt += 1
+                    # push neighboring nodes
+                    for n in it:
+                        if n not in used_n:
+                            stack.append((NODE_UNIT, n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
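The only behavioral difference from the BFS variant is the growth order: the stack plus the `break` after pushing one unvisited edge makes each community a walk away from the seed rather than a frontier. A toy, graphgen-free comparison of the two orders (simplified to nodes only, with plain dicts standing in for the storage):

```python
# Toy illustration of the two growth orders on the path graph a-b-c-d.
from collections import deque

adj = {"a": ["b"], "b": ["a", "c"], "c": ["b", "d"], "d": ["c"]}

def grow(seed, max_units, dfs):
    frontier = [seed] if dfs else deque([seed])
    seen, order = set(), []
    while frontier and len(order) < max_units:
        n = frontier.pop() if dfs else frontier.popleft()
        if n in seen:
            continue
        seen.add(n)
        order.append(n)
        frontier.extend(x for x in adj[n] if x not in seen)
    return order

print(grow("b", 3, dfs=False))  # BFS frontier: ['b', 'a', 'c']
print(grow("b", 3, dfs=True))   # DFS walk:     ['b', 'c', 'd']
```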
graphgen/models/partitioner/ece_partitioner.py
ADDED
@@ -0,0 +1,163 @@
+import asyncio
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.bases import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class ECEPartitioner(BFSPartitioner):
+    """
+    ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
+    We calculate ECE for edges in the KG (represented as 'comprehension loss')
+    and group edges with similar ECE values into the same community.
+    1. Select a sampling strategy.
+    2. Choose a seed unit based on the sampling strategy.
+    3. Expand the community using BFS.
+    4. When expanding, prefer to add units favored by the sampling strategy.
+    5. Stop when the max unit size or the max input length is reached.
+    (A unit is a node or an edge.)
+    """
+
+    @staticmethod
+    def _sort_units(units: list, edge_sampling: str) -> list:
+        """
+        Sort units with edge sampling strategy
+
+        :param units: total units
+        :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+        :return: sorted units
+        """
+        if edge_sampling == "random":
+            random.shuffle(units)
+        elif edge_sampling == "min_loss":
+            units = sorted(
+                units,
+                key=lambda x: x[-1]["loss"],
+            )
+        elif edge_sampling == "max_loss":
+            units = sorted(
+                units,
+                key=lambda x: x[-1]["loss"],
+                reverse=True,
+            )
+        else:
+            raise ValueError(f"Invalid edge sampling: {edge_sampling}")
+        return units
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 10,
+        min_units_per_community: int = 1,
+        max_tokens_per_community: int = 10240,
+        unit_sampling: str = "random",
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
+        edges: List[Tuple[str, str, dict]] = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+        node_dict = dict(nodes)
+        edge_dict = {frozenset((u, v)): d for u, v, d in edges}
+
+        all_units: List[Tuple[str, Any, dict]] = [
+            (NODE_UNIT, nid, d) for nid, d in nodes
+        ] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges]
+
+        used_n: Set[str] = set()
+        used_e: Set[frozenset[str]] = set()
+        communities: List = []
+
+        all_units = self._sort_units(all_units, unit_sampling)
+
+        async def _grow_community(
+            seed_unit: Tuple[str, Any, dict]
+        ) -> Optional[Community]:
+            nonlocal used_n, used_e
+
+            community_nodes: Dict[str, dict] = {}
+            community_edges: Dict[frozenset[str], dict] = {}
+            queue: asyncio.Queue = asyncio.Queue()
+            token_sum = 0
+
+            async def _add_unit(u):
+                nonlocal token_sum
+                t, i, d = u
+                if t == NODE_UNIT:  # node
+                    if i in used_n or i in community_nodes:
+                        return False
+                    community_nodes[i] = d
+                    used_n.add(i)
+                else:  # edge
+                    if i in used_e or i in community_edges:
+                        return False
+                    community_edges[i] = d
+                    used_e.add(i)
+                token_sum += d.get("length", 0)
+                return True
+
+            await _add_unit(seed_unit)
+            await queue.put(seed_unit)
+
+            # BFS
+            while not queue.empty():
+                if (
+                    len(community_nodes) + len(community_edges)
+                    >= max_units_per_community
+                    or token_sum >= max_tokens_per_community
+                ):
+                    break
+
+                cur_type, cur_id, _ = await queue.get()
+
+                neighbors: List[Tuple[str, Any, dict]] = []
+                if cur_type == NODE_UNIT:
+                    for nb_id in adj.get(cur_id, []):
+                        e_key = frozenset((cur_id, nb_id))
+                        if e_key not in used_e and e_key not in community_edges:
+                            neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key]))
+                else:
+                    for n_id in cur_id:
+                        if n_id not in used_n and n_id not in community_nodes:
+                            neighbors.append((NODE_UNIT, n_id, node_dict[n_id]))
+
+                neighbors = self._sort_units(neighbors, unit_sampling)
+                for nb in neighbors:
+                    if (
+                        len(community_nodes) + len(community_edges)
+                        >= max_units_per_community
+                        or token_sum >= max_tokens_per_community
+                    ):
+                        break
+                    if await _add_unit(nb):
+                        await queue.put(nb)
+
+            if len(community_nodes) + len(community_edges) < min_units_per_community:
+                return None
+
+            return Community(
+                id=len(communities),
+                nodes=list(community_nodes.keys()),
+                edges=[(u, v) for (u, v), _ in community_edges.items()],
+            )
+
+        async for unit in tqdm_async(all_units, desc="ECE partition"):
+            utype, uid, _ = unit
+            if (utype == NODE_UNIT and uid in used_n) or (
+                utype == EDGE_UNIT and uid in used_e
+            ):
+                continue
+            comm = await _grow_community(unit)
+            if comm is not None:
+                communities.append(comm)
+
+        return communities
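The three `unit_sampling` strategies only differ in how candidate units are ordered. A small self-contained demonstration (the loss values are invented; in GraphGen they are expected to come from the quiz-and-judge comprehension step referenced in `partition_kg.py` below):

```python
# Units are (kind, id, data) triples; sorting keys off data["loss"].
units = [
    ("e", frozenset(("a", "b")), {"loss": 0.9}),
    ("e", frozenset(("b", "c")), {"loss": 0.1}),
    ("e", frozenset(("c", "d")), {"loss": 0.5}),
]

by_min = sorted(units, key=lambda x: x[-1]["loss"])                # min_loss
by_max = sorted(units, key=lambda x: x[-1]["loss"], reverse=True)  # max_loss
print([u[2]["loss"] for u in by_min])  # [0.1, 0.5, 0.9]
print([u[2]["loss"] for u in by_max])  # [0.9, 0.5, 0.1]
```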
graphgen/models/partitioner/leiden_partitioner.py
ADDED
@@ -0,0 +1,120 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Set, Tuple
+
+import igraph as ig
+from leidenalg import ModularityVertexPartition, find_partition
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class LeidenPartitioner(BasePartitioner):
+    """
+    Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_size: int = 20,
+        use_lcc: bool = False,
+        random_seed: int = 42,
+        **kwargs: Any,
+    ) -> List[Community]:
+        """
+        Leiden Partition follows these steps:
+        1. export the graph from graph storage
+        2. use the leiden algorithm to detect communities, get {node: community_id}
+        3. split large communities if max_size is given
+        4. convert {node: community_id} to List[Community]
+        :param g
+        :param max_size: maximum size of each community, if None or <=0, no limit
+        :param use_lcc: whether to use the largest connected component only
+        :param random_seed
+        :param kwargs: other parameters for the leiden algorithm
+        :return:
+        """
+        nodes = await g.get_all_nodes()  # List[Tuple[str, dict]]
+        edges = await g.get_all_edges()  # List[Tuple[str, str, dict]]
+
+        node2cid: Dict[str, int] = await self._run_leiden(
+            nodes, edges, use_lcc, random_seed
+        )
+
+        if max_size is not None and max_size > 0:
+            node2cid = await self._split_communities(node2cid, max_size)
+
+        cid2nodes: Dict[int, List[str]] = defaultdict(list)
+        for n, cid in node2cid.items():
+            cid2nodes[cid].append(n)
+
+        communities: List[Community] = []
+        for cid, nodes in cid2nodes.items():
+            node_set: Set[str] = set(nodes)
+            comm_edges: List[Tuple[str, str]] = [
+                (u, v) for u, v, _ in edges if u in node_set and v in node_set
+            ]
+            communities.append(Community(id=cid, nodes=nodes, edges=comm_edges))
+        return communities
+
+    @staticmethod
+    async def _run_leiden(
+        nodes: List[Tuple[str, dict]],
+        edges: List[Tuple[str, str, dict]],
+        use_lcc: bool = False,
+        random_seed: int = 42,
+    ) -> Dict[str, int]:
+        # build igraph
+        ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)
+
+        # remove isolated nodes
+        ig_graph.delete_vertices(ig_graph.vs.select(_degree_eq=0))
+
+        node2cid: Dict[str, int] = {}
+        if use_lcc:
+            lcc = ig_graph.components().giant()
+            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
+            for part_id, cluster in enumerate(partition):
+                for v in cluster:
+                    node2cid[lcc.vs[v]["name"]] = part_id
+        else:
+            offset = 0
+            for component in ig_graph.components():
+                subgraph = ig_graph.induced_subgraph(component)
+                partition = find_partition(
+                    subgraph, ModularityVertexPartition, seed=random_seed
+                )
+                for part_id, cluster in enumerate(partition):
+                    for v in cluster:
+                        original_node = subgraph.vs[v]["name"]
+                        node2cid[original_node] = part_id + offset
+                offset += len(partition)
+        return node2cid
+
+    @staticmethod
+    async def _split_communities(
+        node2cid: Dict[str, int], max_size: int
+    ) -> Dict[str, int]:
+        """
+        Split communities larger than max_size into smaller sub-communities.
+        """
+        cid2nodes: Dict[int, List[str]] = defaultdict(list)
+        for n, cid in node2cid.items():
+            cid2nodes[cid].append(n)
+
+        new_mapping: Dict[str, int] = {}
+        new_cid = 0
+        for nodes in cid2nodes.values():
+            if len(nodes) <= max_size:
+                for n in nodes:
+                    new_mapping[n] = new_cid
+                new_cid += 1
+            else:
+                for start in range(0, len(nodes), max_size):
+                    chunk = nodes[start : start + max_size]
+                    for n in chunk:
+                        new_mapping[n] = new_cid
+                    new_cid += 1
+        return new_mapping
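For orientation, the same `find_partition` call on a standalone igraph test graph; this sketch uses only python-igraph and leidenalg, the two dependencies imported above:

```python
import igraph as ig
from leidenalg import ModularityVertexPartition, find_partition

g = ig.Graph.Famous("Zachary")  # classic karate-club test graph
part = find_partition(g, ModularityVertexPartition, seed=42)
print(len(part), "communities; sizes:", [len(c) for c in part])
```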
graphgen/models/storage/__init__.py
CHANGED
@@ -0,0 +1,2 @@
+from .json_storage import JsonKVStorage, JsonListStorage
+from .networkx_storage import NetworkXStorage
graphgen/models/storage/networkx_storage.py
CHANGED
@@ -102,8 +102,8 @@ class NetworkXStorage(BaseGraphStorage):
     async def get_node(self, node_id: str) -> Union[dict, None]:
         return self._graph.nodes.get(node_id)
 
-    async def get_all_nodes(self) -> Union[list[dict], None]:
-        return self._graph.nodes(data=True)
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
+        return list(self._graph.nodes(data=True))
 
     async def node_degree(self, node_id: str) -> int:
         return self._graph.degree(node_id)
@@ -116,8 +116,8 @@ class NetworkXStorage(BaseGraphStorage):
     ) -> Union[dict, None]:
         return self._graph.edges.get((source_node_id, target_node_id))
 
-    async def get_all_edges(self) -> Union[list[dict], None]:
-        return self._graph.edges(data=True)
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
+        return list(self._graph.edges(data=True))
 
     async def get_node_edges(
         self, source_node_id: str
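The `list(...)` wrapping above is the substantive part of this change: NetworkX's `nodes(data=True)` / `edges(data=True)` return live views, while the new partitioners shuffle and index the results, which needs concrete lists. A quick illustration in plain NetworkX:

```python
import networkx as nx

g = nx.Graph()
g.add_edge("a", "b", description="a-b")

view = g.edges(data=True)          # live EdgeDataView, re-evaluated on access
edges = list(g.edges(data=True))   # concrete list of (u, v, data) tuples

g.add_edge("b", "c", description="b-c")
print(len(view))   # 2 -- the view tracks later mutations
print(len(edges))  # 1 -- the snapshot taken before the mutation
```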
graphgen/operators/__init__.py
CHANGED
@@ -1,13 +1,8 @@
-from
-from
-from graphgen.operators.search.search_all import search_all
-
+from .build_kg import build_kg
+from .generate import generate_qas
 from .judge import judge_statement
+from .partition import partition_kg
 from .quiz import quiz
 from .read import read_files
+from .search import search_all
 from .split import chunk_documents
-from .traverse_graph import (
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
-)
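Read together, the re-exports sketch the refactored pipeline: build the KG, partition it into communities, then generate QAs. A rough outline of the assumed wiring (argument lists other than those shown in this commit are elided, so this is an assumption about intent, not code from the diff):

```python
# Assumed wiring of the re-exported operators (names from the import
# list above; "..." marks arguments not shown in this commit):
#
#   docs    = read_files(...)           # 1. ingest
#   chunks  = chunk_documents(...)      #    split
#   kg      = build_kg(...)             # 2. extract the knowledge graph
#   batches = await partition_kg(kg_instance, tokenizer, partition_config)
#   qas     = await generate_qas(llm_client, batches, generation_config)
```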
graphgen/operators/build_kg/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .build_kg import build_kg
graphgen/operators/build_kg/split_kg.py
DELETED
@@ -1,382 +0,0 @@
-import random
-from collections import defaultdict
-from typing import Dict
-
-from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.models import NetworkXStorage
-from graphgen.utils import logger
-
-
-async def _get_node_info(
-    node_id: str,
-    graph_storage: NetworkXStorage,
-) -> dict:
-    """
-    Get node info
-
-    :param node_id: node id
-    :param graph_storage: graph storage instance
-    :return: node info
-    """
-    node_data = await graph_storage.get_node(node_id)
-    return {"node_id": node_id, **node_data}
-
-
-def _get_level_n_edges_by_max_width(
-    edge_adj_list: dict,
-    node_dict: dict,
-    edges: list,
-    nodes,
-    src_edge: tuple,
-    max_depth: int,
-    bidirectional: bool,
-    max_extra_edges: int,
-    edge_sampling: str,
-    loss_strategy: str = "only_edge",
-) -> list:
-    """
-    Get level n edges for an edge.
-    n is decided by max_depth in traverse_strategy
-
-    :param edge_adj_list
-    :param node_dict
-    :param edges
-    :param nodes
-    :param src_edge
-    :param max_depth
-    :param bidirectional
-    :param max_extra_edges
-    :param edge_sampling
-    :return: level n edges
-    """
-    src_id, tgt_id, _ = src_edge
-
-    level_n_edges = []
-
-    start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
-
-    while max_depth > 0 and max_extra_edges > 0:
-        max_depth -= 1
-
-        candidate_edges = [
-            edges[edge_id]
-            for node in start_nodes
-            for edge_id in edge_adj_list[node]
-            if not edges[edge_id][2].get("visited", False)
-        ]
-
-        if not candidate_edges:
-            break
-
-        if len(candidate_edges) >= max_extra_edges:
-            if loss_strategy == "both":
-                er_tuples = [
-                    ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
-                    for edge in candidate_edges
-                ]
-                candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
-                    :max_extra_edges
-                ]
-            elif loss_strategy == "only_edge":
-                candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
-                    :max_extra_edges
-                ]
-            else:
-                raise ValueError(f"Invalid loss strategy: {loss_strategy}")
-
-            for edge in candidate_edges:
-                level_n_edges.append(edge)
-                edge[2]["visited"] = True
-            break
-
-        max_extra_edges -= len(candidate_edges)
-        new_start_nodes = set()
-
-        for edge in candidate_edges:
-            level_n_edges.append(edge)
-            edge[2]["visited"] = True
-
-            if not edge[0] in start_nodes:
-                new_start_nodes.add(edge[0])
-            if not edge[1] in start_nodes:
-                new_start_nodes.add(edge[1])
-
-        start_nodes = new_start_nodes
-
-    return level_n_edges
-
-
-def _get_level_n_edges_by_max_tokens(
-    edge_adj_list: dict,
-    node_dict: dict,
-    edges: list,
-    nodes: list,
-    src_edge: tuple,
-    max_depth: int,
-    bidirectional: bool,
-    max_tokens: int,
-    edge_sampling: str,
-    loss_strategy: str = "only_edge",
-) -> list:
-    """
-    Get level n edges for an edge.
-    n is decided by max_depth in traverse_strategy.
-
-    :param edge_adj_list
-    :param node_dict
-    :param edges
-    :param nodes
-    :param src_edge
-    :param max_depth
-    :param bidirectional
-    :param max_tokens
-    :param edge_sampling
-    :return: level n edges
-    """
-    src_id, tgt_id, src_edge_data = src_edge
-
-    max_tokens -= (
-        src_edge_data["length"]
-        + nodes[node_dict[src_id]][1]["length"]
-        + nodes[node_dict[tgt_id]][1]["length"]
-    )
-
-    level_n_edges = []
-
-    start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
-    temp_nodes = {src_id, tgt_id}
-
-    while max_depth > 0 and max_tokens > 0:
-        max_depth -= 1
-
-        candidate_edges = [
-            edges[edge_id]
-            for node in start_nodes
-            for edge_id in edge_adj_list[node]
-            if not edges[edge_id][2].get("visited", False)
-        ]
-
-        if not candidate_edges:
-            break
-
-        if loss_strategy == "both":
-            er_tuples = [
-                ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
-                for edge in candidate_edges
-            ]
-            candidate_edges = _sort_tuples(er_tuples, edge_sampling)
-        elif loss_strategy == "only_edge":
-            candidate_edges = _sort_edges(candidate_edges, edge_sampling)
-        else:
-            raise ValueError(f"Invalid loss strategy: {loss_strategy}")
-
-        for edge in candidate_edges:
-            max_tokens -= edge[2]["length"]
-            if not edge[0] in temp_nodes:
-                max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
-            if not edge[1] in temp_nodes:
-                max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
-
-            if max_tokens < 0:
-                return level_n_edges
-
-            level_n_edges.append(edge)
-            edge[2]["visited"] = True
-            temp_nodes.add(edge[0])
-            temp_nodes.add(edge[1])
-
-        new_start_nodes = set()
-        for edge in candidate_edges:
-            if not edge[0] in start_nodes:
-                new_start_nodes.add(edge[0])
-            if not edge[1] in start_nodes:
-                new_start_nodes.add(edge[1])
-
-        start_nodes = new_start_nodes
-
-    return level_n_edges
-
-
-def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
-    """
-    Sort edges with edge sampling strategy
-
-    :param er_tuples: [(nodes:list, edge:tuple)]
-    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
-    :return: sorted edges
-    """
-    if edge_sampling == "random":
-        er_tuples = random.sample(er_tuples, len(er_tuples))
-    elif edge_sampling == "min_loss":
-        er_tuples = sorted(
-            er_tuples,
-            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
-        )
-    elif edge_sampling == "max_loss":
-        er_tuples = sorted(
-            er_tuples,
-            key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
-            reverse=True,
-        )
-    else:
-        raise ValueError(f"Invalid edge sampling: {edge_sampling}")
-    edges = [edge for _, edge in er_tuples]
-    return edges
-
-
-def _sort_edges(edges: list, edge_sampling: str) -> list:
-    """
-    Sort edges with edge sampling strategy
-
-    :param edges: total edges
-    :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
-    :return: sorted edges
-    """
-    if edge_sampling == "random":
-        random.shuffle(edges)
-    elif edge_sampling == "min_loss":
-        edges = sorted(edges, key=lambda x: x[2]["loss"])
-    elif edge_sampling == "max_loss":
-        edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
-    else:
-        raise ValueError(f"Invalid edge sampling: {edge_sampling}")
-    return edges
-
-
-async def get_batches_with_strategy(  # pylint: disable=too-many-branches
-    nodes: list,
-    edges: list,
-    graph_storage: NetworkXStorage,
-    traverse_strategy: Dict,
-):
-    expand_method = traverse_strategy["expand_method"]
-    if expand_method == "max_width":
-        logger.info("Using max width strategy")
-    elif expand_method == "max_tokens":
-        logger.info("Using max tokens strategy")
-    else:
-        raise ValueError(f"Invalid expand method: {expand_method}")
-
-    max_depth = traverse_strategy["max_depth"]
-    edge_sampling = traverse_strategy["edge_sampling"]
-
-    # build the adjacency list
-    edge_adj_list = defaultdict(list)
-    node_dict = {}
-    processing_batches = []
-
-    node_cache = {}
-
-    async def get_cached_node_info(node_id: str) -> dict:
-        if node_id not in node_cache:
-            node_cache[node_id] = await _get_node_info(node_id, graph_storage)
-        return node_cache[node_id]
-
-    for i, (node_name, _) in enumerate(nodes):
-        node_dict[node_name] = i
-
-    if traverse_strategy["loss_strategy"] == "both":
-        er_tuples = [
-            ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
-            for edge in edges
-        ]
-        edges = _sort_tuples(er_tuples, edge_sampling)
-    elif traverse_strategy["loss_strategy"] == "only_edge":
-        edges = _sort_edges(edges, edge_sampling)
-    else:
-        raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")
-
-    for i, (src, tgt, _) in enumerate(edges):
-        edge_adj_list[src].append(i)
-        edge_adj_list[tgt].append(i)
-
-    for edge in tqdm_async(edges, desc="Preparing batches"):
-        if "visited" in edge[2] and edge[2]["visited"]:
-            continue
-
-        edge[2]["visited"] = True
-
-        _process_nodes = []
-        _process_edges = []
-
-        src_id = edge[0]
-        tgt_id = edge[1]
-
-        _process_nodes.extend(
-            [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
-        )
-        _process_edges.append(edge)
-
-        if expand_method == "max_width":
-            level_n_edges = _get_level_n_edges_by_max_width(
-                edge_adj_list,
-                node_dict,
-                edges,
-                nodes,
-                edge,
-                max_depth,
-                traverse_strategy["bidirectional"],
-                traverse_strategy["max_extra_edges"],
-                edge_sampling,
-                traverse_strategy["loss_strategy"],
-            )
-        else:
-            level_n_edges = _get_level_n_edges_by_max_tokens(
-                edge_adj_list,
-                node_dict,
-                edges,
-                nodes,
-                edge,
-                max_depth,
-                traverse_strategy["bidirectional"],
-                traverse_strategy["max_tokens"],
-                edge_sampling,
-                traverse_strategy["loss_strategy"],
-            )
-
-        for _edge in level_n_edges:
-            _process_nodes.append(await get_cached_node_info(_edge[0]))
-            _process_nodes.append(await get_cached_node_info(_edge[1]))
-            _process_edges.append(_edge)
-
-        # deduplicate
-        _process_nodes = list(
-            {node["node_id"]: node for node in _process_nodes}.values()
-        )
-        _process_edges = list(
-            {(edge[0], edge[1]): edge for edge in _process_edges}.values()
-        )
-
-        processing_batches.append((_process_nodes, _process_edges))
-
-    logger.info("Processing batches: %d", len(processing_batches))
-
-    # isolated nodes
-    isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
-    if isolated_node_strategy == "add":
-        processing_batches = await _add_isolated_nodes(
-            nodes, processing_batches, graph_storage
-        )
-        logger.info(
-            "Processing batches after adding isolated nodes: %d",
-            len(processing_batches),
-        )
-
-    return processing_batches
-
-
-async def _add_isolated_nodes(
-    nodes: list,
-    processing_batches: list,
-    graph_storage: NetworkXStorage,
-) -> list:
-    visited_nodes = set()
-    for _process_nodes, _process_edges in processing_batches:
-        for node in _process_nodes:
-            visited_nodes.add(node["node_id"])
-    for node in nodes:
-        if node[0] not in visited_nodes:
-            _process_nodes = [await _get_node_info(node[0], graph_storage)]
-            processing_batches.append((_process_nodes, []))
-    return processing_batches
graphgen/operators/generate/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .generate_qas import generate_qas
graphgen/operators/generate/generate_cot.py
DELETED
@@ -1,117 +0,0 @@
-import asyncio
-from typing import Dict, List, Tuple
-
-from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
-from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
-from graphgen.utils import compute_content_hash, detect_main_language
-
-
-async def generate_cot(
-    graph_storage: NetworkXStorage,
-    synthesizer_llm_client: OpenAIClient,
-    method_params: Dict = None,
-):
-    method = method_params.get("method", "leiden")
-    detector = CommunityDetector(
-        graph_storage=graph_storage, method=method, method_params=method_params
-    )
-
-    results = await detector.detect_communities()
-
-    # Convert results to a format suitable for summarization
-    communities = {}
-    for node, community_id in results.items():
-        if community_id not in communities:
-            communities[community_id] = []
-        communities[community_id].append(node)
-
-    if not communities:
-        return {}
-
-    semaphore = asyncio.Semaphore(value=1000)
-
-    async def _generate_from_single_community(
-        c_id: int, nodes: List[str]
-    ) -> Tuple[int, Tuple[str, str, str]]:
-        """Summarize a single community."""
-        async with semaphore:
-            entities: List[str] = []
-            relationships: List[str] = []
-
-            for n in nodes:
-                node_data = await graph_storage.get_node(n)
-                if node_data is not None:
-                    entities.append(f"({n}: {node_data.get('description')})")
-
-                edges = await graph_storage.get_node_edges(n)
-                for edge in edges:
-                    target = edge[1]
-                    if target in nodes:
-                        edge_data = await graph_storage.get_edge(n, target)
-                        relationships.append(
-                            f"({n}) - [{edge_data['description']}] -> ({target})"
-                        )
-
-            entities_str = "\n".join(entities)
-            relationships_str = "\n".join(relationships)
-
-            language = (
-                "English"
-                if detect_main_language(entities_str + relationships_str) == "en"
-                else "Chinese"
-            )
-
-            prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
-                entities=entities_str,
-                relationships=relationships_str,
-            )
-
-            cot_template = await synthesizer_llm_client.generate_answer(prompt)
-
-            if "问题:" in cot_template and "推理路径设计:" in cot_template:
-                question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
-                reasoning_path = cot_template.split("推理路径设计:")[1].strip()
-            elif (
-                "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
-            ):
-                question = (
-                    cot_template.split("Question:")[1]
-                    .split("Reasoning-Path Design:")[0]
-                    .strip()
-                )
-                reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
-            else:
-                raise ValueError("COT template format is incorrect.")
-
-            prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
-                entities=entities_str,
-                relationships=relationships_str,
-                question=question,
-                reasoning_template=reasoning_path,
-            )
-
-            cot_answer = await synthesizer_llm_client.generate_answer(prompt)
-
-            return c_id, (question, reasoning_path, cot_answer)
-
-    cid_nodes = list(communities.items())
-
-    results: Dict = {}
-    async for coro in tqdm_async(
-        asyncio.as_completed(
-            [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
-        ),
-        total=len(cid_nodes),
-        desc="[Generating COT] Generating CoT data from communities",
-        unit="community",
-    ):
-        cid, (q, r, a) = await coro
-        results[compute_content_hash(q)] = {
-            "question": q,
-            "reasoning_path": r,
-            "answer": a,
-        }
-
-    return results
graphgen/operators/generate/generate_qas.py
ADDED
@@ -0,0 +1,58 @@
+from typing import Any
+
+from graphgen.bases import BaseLLMClient
+from graphgen.models import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
+from graphgen.utils import logger, run_concurrent
+
+
+async def generate_qas(
+    llm_client: BaseLLMClient,
+    batches: list[
+        tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ]
+    ],
+    generation_config: dict,
+) -> list[dict[str, Any]]:
+    """
+    Generate question-answer pairs based on nodes and edges.
+    :param llm_client: LLM client
+    :param batches
+    :param generation_config
+    :return: QA pairs
+    """
+    mode = generation_config["mode"]
+    logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
+
+    if mode == "atomic":
+        generator = AtomicGenerator(llm_client)
+    elif mode == "aggregated":
+        generator = AggregatedGenerator(llm_client)
+    elif mode == "multi_hop":
+        generator = MultiHopGenerator(llm_client)
+    elif mode == "cot":
+        generator = CoTGenerator(llm_client)
+    else:
+        raise ValueError(f"Unsupported generation mode: {mode}")
+
+    results = await run_concurrent(
+        generator.generate,
+        batches,
+        desc="[4/4]Generating QAs",
+        unit="batch",
+    )
+
+    # format
+    data_format = generation_config["data_format"]
+    logger.info("Output data format: %s", data_format)
+
+    results = generator.format_generation_results(
+        results, output_data_format=data_format
+    )
+
+    return results
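For reference, the shape of `generation_config` this dispatcher reads; only the `mode` and `data_format` keys are consumed above, and the values here are illustrative:

```python
generation_config = {
    "mode": "multi_hop",      # one of: atomic, aggregated, multi_hop, cot
    "data_format": "Alpaca",  # passed through to format_generation_results
}
# qas = await generate_qas(llm_client, batches, generation_config)
```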
graphgen/operators/partition/__init__.py
ADDED
@@ -0,0 +1 @@
+from .partition_kg import partition_kg
graphgen/operators/partition/partition_kg.py
ADDED
@@ -0,0 +1,48 @@
+from typing import Any
+
+from graphgen.bases import BaseGraphStorage, BaseTokenizer
+from graphgen.models import (
+    BFSPartitioner,
+    DFSPartitioner,
+    ECEPartitioner,
+    LeidenPartitioner,
+)
+from graphgen.utils import logger
+
+from .pre_tokenize import pre_tokenize
+
+
+async def partition_kg(
+    kg_instance: BaseGraphStorage,
+    tokenizer: Any = BaseTokenizer,
+    partition_config: dict = None,
+) -> list[
+    tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
+]:
+    method = partition_config["method"]
+    method_params = partition_config["method_params"]
+    if method == "bfs":
+        logger.info("Partitioning knowledge graph using BFS method.")
+        partitioner = BFSPartitioner()
+    elif method == "dfs":
+        logger.info("Partitioning knowledge graph using DFS method.")
+        partitioner = DFSPartitioner()
+    elif method == "ece":
+        logger.info("Partitioning knowledge graph using ECE method.")
+        # TODO: before ECE partitioning, we need to:
+        # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
+        # 2. pre-tokenize nodes and edges to get the token length
+        edges = await kg_instance.get_all_edges()
+        nodes = await kg_instance.get_all_nodes()
+        await pre_tokenize(kg_instance, tokenizer, edges, nodes)
+        partitioner = ECEPartitioner()
+    elif method == "leiden":
+        logger.info("Partitioning knowledge graph using Leiden method.")
+        partitioner = LeidenPartitioner()
+    else:
+        raise ValueError(f"Unsupported partition method: {method}")
+
+    communities = await partitioner.partition(g=kg_instance, **method_params)
+    logger.info("Partitioned the graph into %d communities.", len(communities))
+    batches = await partitioner.community2batch(communities, g=kg_instance)
+    return batches
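An illustrative `partition_config` for this dispatcher; the `method_params` keys mirror `ECEPartitioner.partition`'s keyword arguments, and the values are made up for the example:

```python
partition_config = {
    "method": "ece",
    "method_params": {
        "max_units_per_community": 10,
        "min_units_per_community": 1,
        "max_tokens_per_community": 10240,
        "unit_sampling": "max_loss",
    },
}
# batches = await partition_kg(kg_instance, tokenizer, partition_config)
```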
graphgen/operators/partition/pre_tokenize.py
ADDED
@@ -0,0 +1,47 @@
+import asyncio
+from typing import List, Tuple
+
+from graphgen.bases import BaseGraphStorage, BaseTokenizer
+from graphgen.utils import run_concurrent
+
+
+async def pre_tokenize(
+    graph_storage: BaseGraphStorage,
+    tokenizer: BaseTokenizer,
+    edges: List[Tuple],
+    nodes: List[Tuple],
+) -> Tuple[List, List]:
+    """Fill in missing token lengths for edges/nodes and write them back to storage; concurrency capped at 1000, with progress bars."""
+    sem = asyncio.Semaphore(1000)
+
+    async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
+        async with sem:
+            data = obj[1] if is_node else obj[2]
+            if "length" not in data:
+                loop = asyncio.get_event_loop()
+                data["length"] = len(
+                    await loop.run_in_executor(
+                        None, tokenizer.encode, data["description"]
+                    )
+                )
+            if is_node:
+                await graph_storage.update_node(obj[0], obj[1])
+            else:
+                await graph_storage.update_edge(obj[0], obj[1], obj[2])
+            return obj
+
+    new_edges, new_nodes = await asyncio.gather(
+        run_concurrent(
+            lambda e: _patch_and_write(e, is_node=False),
+            edges,
+            desc="Pre-tokenizing edges",
+        ),
+        run_concurrent(
+            lambda n: _patch_and_write(n, is_node=True),
+            nodes,
+            desc="Pre-tokenizing nodes",
+        ),
+    )
+
+    await graph_storage.index_done_callback()
+    return new_edges, new_nodes
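The pattern worth noting here is offloading the CPU-bound `tokenizer.encode` to the default thread-pool executor so it does not block the event loop. A stripped-down sketch of that pattern, with a stand-in encoder:

```python
import asyncio

def encode(text: str) -> list[int]:
    # stand-in for a real tokenizer's encode()
    return [ord(c) for c in text]

async def token_length(text: str) -> int:
    loop = asyncio.get_event_loop()
    tokens = await loop.run_in_executor(None, encode, text)
    return len(tokens)

print(asyncio.run(token_length("hello")))  # 5
```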
graphgen/operators/search/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .search_all import search_all
graphgen/operators/traverse_graph.py
DELETED
@@ -1,540 +0,0 @@
-import asyncio
-from typing import Dict
-
-import gradio as gr
-from tqdm.asyncio import tqdm as tqdm_async
-
-from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
-from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
-from graphgen.templates import (
-    ANSWER_REPHRASING_PROMPT,
-    MULTI_HOP_GENERATION_PROMPT,
-    QUESTION_GENERATION_PROMPT,
-)
-from graphgen.utils import compute_content_hash, detect_main_language, logger
-
-
-async def _pre_tokenize(
-    graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
-) -> tuple:
-
-    sem = asyncio.Semaphore(1000)
-
-    async def handle_edge(edge: tuple) -> tuple:
-        async with sem:
-            if "length" not in edge[2]:
-                edge[2]["length"] = len(
-                    await asyncio.get_event_loop().run_in_executor(
-                        None, tokenizer.encode, edge[2]["description"]
-                    )
-                )
-            return edge
-
-    async def handle_node(node: dict) -> dict:
-        async with sem:
-            if "length" not in node[1]:
-                node[1]["length"] = len(
-                    await asyncio.get_event_loop().run_in_executor(
-                        None, tokenizer.encode, node[1]["description"]
-                    )
-                )
-            return node
-
-    new_edges = []
-    new_nodes = []
-
-    for result in tqdm_async(
-        asyncio.as_completed([handle_edge(edge) for edge in edges]),
-        total=len(edges),
-        desc="Pre-tokenizing edges",
-    ):
-        new_edge = await result
-        await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
-        new_edges.append(new_edge)
-
-    for result in tqdm_async(
-        asyncio.as_completed([handle_node(node) for node in nodes]),
-        total=len(nodes),
-        desc="Pre-tokenizing nodes",
-    ):
-        new_node = await result
-        await graph_storage.update_node(new_node[0], new_node[1])
-        new_nodes.append(new_node)
-
-    await graph_storage.index_done_callback()
-    return new_edges, new_nodes
-
-
-async def _construct_rephrasing_prompt(
-    _process_nodes: list,
-    _process_edges: list,
-    text_chunks_storage: JsonKVStorage,
-    add_context: bool = False,
-) -> str:
-    entities = [
-        f"{_process_node['node_id']}: {_process_node['description']}"
-        for _process_node in _process_nodes
-    ]
-    relations = [
-        f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
-        for _process_edge in _process_edges
-    ]
-
-    entities_str = "\n".join(
-        [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
-    )
-    relations_str = "\n".join(
-        [f"{index + 1}. {relation}" for index, relation in enumerate(relations)]
-    )
-    language = (
-        "Chinese"
-        if detect_main_language(entities_str + relations_str) == "zh"
-        else "English"
-    )
-
-    if add_context:
-        original_ids = [
-            node["source_id"].split("<SEP>")[0] for node in _process_nodes
-        ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
-
-        original_ids = list(set(original_ids))
-        original_text = await text_chunks_storage.get_by_ids(original_ids)
-        original_text = "\n".join(
-            [
-                f"{index + 1}. {text['content']}"
-                for index, text in enumerate(original_text)
-            ]
-        )
-
-        prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
-            language=language,
-            original_text=original_text,
-            entities=entities_str,
-            relationships=relations_str,
-        )
-        return prompt
-
-    prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
-        language=language, entities=entities_str, relationships=relations_str
-    )
-    return prompt
-
-
-def get_average_loss(batch: tuple, loss_strategy: str) -> float:
-    try:
-        if loss_strategy == "only_edge":
-            return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1])
-        if loss_strategy == "both":
-            return sum(edge[2]["loss"] for edge in batch[1]) + sum(
-                node["loss"] for node in batch[0]
-            ) / (len(batch[0]) + len(batch[1]))
-        raise ValueError("Invalid loss strategy")
-    except Exception as e:  # pylint: disable=broad-except
-        logger.warning(
-            "Loss not found in some nodes or edges, setting loss to -1.0: %s", e
-        )
-        return -1.0
-
-
-def _post_process_synthetic_data(data):
-    block = data.split("\n\n")
-    qas = []
-    for line in block:
-        if "Question:" in line and "Answer:" in line:
-            question = line.split("Question:")[1].split("Answer:")[0].strip()
-            answer = line.split("Answer:")[1].strip()
-            qas.append({"question": question, "answer": answer})
-        elif "问题:" in line and "答案:" in line:
-            question = line.split("问题:")[1].split("答案:")[0].strip()
-            answer = line.split("答案:")[1].strip()
-            qas.append({"question": question, "answer": answer})
-        elif "问题:" in line and "回答:" in line:
-            question = line.split("问题:")[1].split("回答:")[0].strip()
-            answer = line.split("回答:")[1].strip()
-            qas.append({"question": question, "answer": answer})
-    return qas
-
-
-async def traverse_graph_for_aggregated(
-    llm_client: OpenAIClient,
-    tokenizer: Tokenizer,
-    graph_storage: NetworkXStorage,
-    traverse_strategy: Dict,
-    text_chunks_storage: JsonKVStorage,
-    progress_bar: gr.Progress = None,
-    max_concurrent: int = 1000,
-) -> dict:
-    """
-    Traverse the graph
-
-    :param llm_client
-    :param tokenizer
-    :param graph_storage
-    :param traverse_strategy
-    :param text_chunks_storage
-    :param progress_bar
-    :param max_concurrent
-    :return: question and answer
-    """
-
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    async def _process_nodes_and_edges(
-        _process_nodes: list,
-        _process_edges: list,
-    ) -> str:
-        prompt = await _construct_rephrasing_prompt(
-            _process_nodes, _process_edges, text_chunks_storage, add_context=False
-        )
-        context = await llm_client.generate_answer(prompt)
-
-        # post-process the context
-        if context.startswith("Rephrased Text:"):
-            context = context[len("Rephrased Text:") :].strip()
-        elif context.startswith("重述文本:"):
-            context = context[len("重述文本:") :].strip()
-
-        return context
-
-    async def _process_single_batch(
-        _process_batch: tuple, question_type: str = "single"
-    ) -> dict:
-        async with semaphore:
-            context = await _process_nodes_and_edges(
-                _process_batch[0],
-                _process_batch[1],
-            )
-
-            language = "Chinese" if detect_main_language(context) == "zh" else "English"
-            pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
-                edge[2]["length"] for edge in _process_batch[1]
-            )
-
-            if question_type == "single":
-                question = await llm_client.generate_answer(
-                    QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
-                        answer=context
-                    )
-                )
-                if question.startswith("Question:"):
-                    question = question[len("Question:") :].strip()
-                elif question.startswith("问题:"):
-                    question = question[len("问题:") :].strip()
-
-                logger.info(
-                    "%d nodes and %d edges processed",
-                    len(_process_batch[0]),
-                    len(_process_batch[1]),
-                )
-                logger.info("Pre-length: %s", pre_length)
-                logger.info("Question: %s", question)
-                logger.info("Answer: %s", context)
-
-                return {
-                    compute_content_hash(context): {
-                        "question": question,
-                        "answer": context,
-                        "loss": get_average_loss(
-                            _process_batch, traverse_strategy["loss_strategy"]
-                        ),
-                    }
-                }
-
-            content = await llm_client.generate_answer(
-                QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
-                    doc=context
-                )
-            )
-            qas = _post_process_synthetic_data(content)
-
-            if len(qas) == 0:
-                logger.error(
-                    "Error occurred while processing batch, question or answer is None"
-                )
-                return {}
-
-            final_results = {}
-            logger.info(
-                "%d nodes and %d edges processed",
-                len(_process_batch[0]),
-                len(_process_batch[1]),
-            )
-            logger.info("Pre-length: %s", pre_length)
-            for qa in qas:
-                logger.info("Question: %s", qa["question"])
-                logger.info("Answer: %s", qa["answer"])
-                final_results[compute_content_hash(qa["question"])] = {
-                    "question": qa["question"],
-                    "answer": qa["answer"],
-                    "loss": get_average_loss(
-                        _process_batch, traverse_strategy["loss_strategy"]
-                    ),
-                }
-            return final_results
-
-    results = {}
-    edges = list(await graph_storage.get_all_edges())
-    nodes = list(await graph_storage.get_all_nodes())
-
-    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
-
-    processing_batches = await get_batches_with_strategy(
-        nodes, edges, graph_storage, traverse_strategy
-    )
-
-    for result in tqdm_async(
-        asyncio.as_completed(
-            [_process_single_batch(batch) for batch in processing_batches]
-        ),
-        total=len(processing_batches),
-        desc="[4/4]Generating QAs",
-    ):
-        try:
-            if progress_bar is not None:
-                progress_bar(
-                    len(results) / len(processing_batches), desc="[4/4]Generating QAs"
-                )
-            results.update(await result)
-            if progress_bar is not None and len(results) == len(processing_batches):
-                progress_bar(1, desc="[4/4]Generating QAs")
-        except Exception as e:  # pylint: disable=broad-except
-            logger.error("Error occurred while generating QA: %s", e)
-
-    return results
-
-
-# pylint: disable=too-many-branches, too-many-statements
-async def traverse_graph_for_atomic(
-    llm_client: OpenAIClient,
-    tokenizer: Tokenizer,
-    graph_storage: NetworkXStorage,
-    traverse_strategy: Dict,
-    text_chunks_storage: JsonKVStorage,
-    progress_bar: gr.Progress = None,
-    max_concurrent: int = 1000,
-) -> dict:
-    """
-    Traverse the graph atomically
-
-    :param llm_client
-    :param tokenizer
-    :param graph_storage
-    :param traverse_strategy
-    :param text_chunks_storage
-    :param progress_bar
-    :param max_concurrent
-    :return: question and answer
-    """
-
-    semaphore = asyncio.Semaphore(max_concurrent)
-
-    def _parse_qa(qa: str) -> tuple:
-        if "Question:" in qa and "Answer:" in qa:
-            question = qa.split("Question:")[1].split("Answer:")[0].strip()
-            answer = qa.split("Answer:")[1].strip()
-        elif "问题:" in qa and "答案:" in qa:
-            question = qa.split("问题:")[1].split("答案:")[0].strip()
-            answer = qa.split("答案:")[1].strip()
-        else:
-            return None, None
-        return question.strip('"'), answer.strip('"')
-
-    async def _generate_question(node_or_edge: tuple):
-        if len(node_or_edge) == 2:
-            des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
-            loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
-        else:
-            des = node_or_edge[2]["description"]
-            loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
-
-        async with semaphore:
-            try:
-                language = "Chinese" if detect_main_language(des) == "zh" else "English"
-
-                qa = await llm_client.generate_answer(
-                    QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
-                        doc=des
-                    )
-                )
-
-                question, answer = _parse_qa(qa)
-                if question is None or answer is None:
-                    return {}
| 363 |
-
|
| 364 |
-
question = question.strip('"')
|
| 365 |
-
answer = answer.strip('"')
|
| 366 |
-
|
| 367 |
-
logger.info("Question: %s", question)
|
| 368 |
-
logger.info("Answer: %s", answer)
|
| 369 |
-
return {
|
| 370 |
-
compute_content_hash(question): {
|
| 371 |
-
"question": question,
|
| 372 |
-
"answer": answer,
|
| 373 |
-
"loss": loss,
|
| 374 |
-
}
|
| 375 |
-
}
|
| 376 |
-
except Exception as e: # pylint: disable=broad-except
|
| 377 |
-
logger.error("Error occurred while generating question: %s", e)
|
| 378 |
-
return {}
|
| 379 |
-
|
| 380 |
-
results = {}
|
| 381 |
-
edges = list(await graph_storage.get_all_edges())
|
| 382 |
-
nodes = list(await graph_storage.get_all_nodes())
|
| 383 |
-
|
| 384 |
-
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
| 385 |
-
|
| 386 |
-
tasks = []
|
| 387 |
-
for node in nodes:
|
| 388 |
-
if "<SEP>" in node[1]["description"]:
|
| 389 |
-
description_list = node[1]["description"].split("<SEP>")
|
| 390 |
-
for item in description_list:
|
| 391 |
-
tasks.append((node[0], {"description": item}))
|
| 392 |
-
if "loss" in node[1]:
|
| 393 |
-
tasks[-1][1]["loss"] = node[1]["loss"]
|
| 394 |
-
else:
|
| 395 |
-
tasks.append((node[0], node[1]))
|
| 396 |
-
for edge in edges:
|
| 397 |
-
if "<SEP>" in edge[2]["description"]:
|
| 398 |
-
description_list = edge[2]["description"].split("<SEP>")
|
| 399 |
-
for item in description_list:
|
| 400 |
-
tasks.append((edge[0], edge[1], {"description": item}))
|
| 401 |
-
if "loss" in edge[2]:
|
| 402 |
-
tasks[-1][2]["loss"] = edge[2]["loss"]
|
| 403 |
-
else:
|
| 404 |
-
tasks.append((edge[0], edge[1], edge[2]))
|
| 405 |
-
|
| 406 |
-
for result in tqdm_async(
|
| 407 |
-
asyncio.as_completed([_generate_question(task) for task in tasks]),
|
| 408 |
-
total=len(tasks),
|
| 409 |
-
desc="[4/4]Generating QAs",
|
| 410 |
-
):
|
| 411 |
-
try:
|
| 412 |
-
if progress_bar is not None:
|
| 413 |
-
progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
|
| 414 |
-
results.update(await result)
|
| 415 |
-
if progress_bar is not None and len(results) == len(tasks):
|
| 416 |
-
progress_bar(1, desc="[4/4]Generating QAs")
|
| 417 |
-
except Exception as e: # pylint: disable=broad-except
|
| 418 |
-
logger.error("Error occurred while generating QA: %s", e)
|
| 419 |
-
return results
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
async def traverse_graph_for_multi_hop(
|
| 423 |
-
llm_client: OpenAIClient,
|
| 424 |
-
tokenizer: Tokenizer,
|
| 425 |
-
graph_storage: NetworkXStorage,
|
| 426 |
-
traverse_strategy: Dict,
|
| 427 |
-
text_chunks_storage: JsonKVStorage,
|
| 428 |
-
progress_bar: gr.Progress = None,
|
| 429 |
-
max_concurrent: int = 1000,
|
| 430 |
-
) -> dict:
|
| 431 |
-
"""
|
| 432 |
-
Traverse the graph for multi-hop
|
| 433 |
-
|
| 434 |
-
:param llm_client
|
| 435 |
-
:param tokenizer
|
| 436 |
-
:param graph_storage
|
| 437 |
-
:param traverse_strategy
|
| 438 |
-
:param text_chunks_storage
|
| 439 |
-
:param progress_bar
|
| 440 |
-
:param max_concurrent
|
| 441 |
-
:return: question and answer
|
| 442 |
-
"""
|
| 443 |
-
semaphore = asyncio.Semaphore(max_concurrent)
|
| 444 |
-
|
| 445 |
-
results = {}
|
| 446 |
-
edges = list(await graph_storage.get_all_edges())
|
| 447 |
-
nodes = list(await graph_storage.get_all_nodes())
|
| 448 |
-
|
| 449 |
-
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
|
| 450 |
-
|
| 451 |
-
processing_batches = await get_batches_with_strategy(
|
| 452 |
-
nodes, edges, graph_storage, traverse_strategy
|
| 453 |
-
)
|
| 454 |
-
|
| 455 |
-
async def _process_single_batch(_process_batch: tuple) -> dict:
|
| 456 |
-
async with semaphore:
|
| 457 |
-
try:
|
| 458 |
-
language = (
|
| 459 |
-
"Chinese"
|
| 460 |
-
if detect_main_language(_process_batch[0][0]["description"]) == "zh"
|
| 461 |
-
else "English"
|
| 462 |
-
)
|
| 463 |
-
|
| 464 |
-
_process_nodes = _process_batch[0]
|
| 465 |
-
_process_edges = _process_batch[1]
|
| 466 |
-
|
| 467 |
-
entities = [
|
| 468 |
-
f"{_process_node['node_id']}: {_process_node['description']}"
|
| 469 |
-
for _process_node in _process_nodes
|
| 470 |
-
]
|
| 471 |
-
|
| 472 |
-
relations = [
|
| 473 |
-
f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
|
| 474 |
-
for _process_edge in _process_edges
|
| 475 |
-
]
|
| 476 |
-
|
| 477 |
-
entities_str = "\n".join(
|
| 478 |
-
[f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
|
| 479 |
-
)
|
| 480 |
-
relations_str = "\n".join(
|
| 481 |
-
[
|
| 482 |
-
f"{index + 1}. {relation}"
|
| 483 |
-
for index, relation in enumerate(relations)
|
| 484 |
-
]
|
| 485 |
-
)
|
| 486 |
-
|
| 487 |
-
prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
|
| 488 |
-
entities=entities_str, relationships=relations_str
|
| 489 |
-
)
|
| 490 |
-
|
| 491 |
-
context = await llm_client.generate_answer(prompt)
|
| 492 |
-
|
| 493 |
-
# post-process the context
|
| 494 |
-
if "Question:" in context and "Answer:" in context:
|
| 495 |
-
question = context.split("Question:")[1].split("Answer:")[0].strip()
|
| 496 |
-
answer = context.split("Answer:")[1].strip()
|
| 497 |
-
elif "问题:" in context and "答案:" in context:
|
| 498 |
-
question = context.split("问题:")[1].split("答案:")[0].strip()
|
| 499 |
-
answer = context.split("答案:")[1].strip()
|
| 500 |
-
else:
|
| 501 |
-
return {}
|
| 502 |
-
|
| 503 |
-
question = question.strip('"')
|
| 504 |
-
answer = answer.strip('"')
|
| 505 |
-
|
| 506 |
-
logger.info("Question: %s", question)
|
| 507 |
-
logger.info("Answer: %s", answer)
|
| 508 |
-
|
| 509 |
-
return {
|
| 510 |
-
compute_content_hash(question): {
|
| 511 |
-
"question": question,
|
| 512 |
-
"answer": answer,
|
| 513 |
-
"loss": get_average_loss(
|
| 514 |
-
_process_batch, traverse_strategy["loss_strategy"]
|
| 515 |
-
),
|
| 516 |
-
}
|
| 517 |
-
}
|
| 518 |
-
|
| 519 |
-
except Exception as e: # pylint: disable=broad-except
|
| 520 |
-
logger.error("Error occurred while processing batch: %s", e)
|
| 521 |
-
return {}
|
| 522 |
-
|
| 523 |
-
async for result in tqdm_async(
|
| 524 |
-
asyncio.as_completed(
|
| 525 |
-
[_process_single_batch(batch) for batch in processing_batches]
|
| 526 |
-
),
|
| 527 |
-
total=len(processing_batches),
|
| 528 |
-
desc="[4/4]Generating QAs",
|
| 529 |
-
):
|
| 530 |
-
try:
|
| 531 |
-
if progress_bar is not None:
|
| 532 |
-
progress_bar(
|
| 533 |
-
len(results) / len(processing_batches), desc="[4/4]Generating QAs"
|
| 534 |
-
)
|
| 535 |
-
results.update(await result)
|
| 536 |
-
if progress_bar is not None and len(results) == len(processing_batches):
|
| 537 |
-
progress_bar(1, desc="[4/4]Generating QAs")
|
| 538 |
-
except Exception as e: # pylint: disable=broad-except
|
| 539 |
-
logger.error("Error occurred while generating QA: %s", e)
|
| 540 |
-
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
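Of the three deleted traversal functions, only `traverse_graph_for_multi_hop` serialized the subgraph itself into the prompt. A minimal, self-contained sketch of that serialization step, with hypothetical sample data (the real code read nodes and edges from `NetworkXStorage`):

```python
# Sketch of the entity/relation serialization done by the deleted
# traverse_graph_for_multi_hop. The sample data below is made up;
# the numbering and " -- " separator mirror the deleted code.
nodes = [
    {"node_id": "Marie Curie", "description": "Physicist and chemist."},
    {"node_id": "Radium", "description": "A radioactive element."},
]
edges = [("Marie Curie", "Radium", {"description": "discovered"})]

entities = [f"{n['node_id']}: {n['description']}" for n in nodes]
relations = [f"{s} -- {t}: {attrs['description']}" for s, t, attrs in edges]

# Numbered, newline-joined blocks, exactly as the deleted code built them
# before formatting MULTI_HOP_GENERATION_PROMPT.
entities_str = "\n".join(f"{i + 1}. {e}" for i, e in enumerate(entities))
relations_str = "\n".join(f"{i + 1}. {r}" for i, r in enumerate(relations))

print(entities_str)
# 1. Marie Curie: Physicist and chemist.
# 2. Radium: A radioactive element.
print(relations_str)
# 1. Marie Curie -- Radium: discovered
```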
graphgen/templates/__init__.py
CHANGED
@@ -1,10 +1,13 @@
-from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
-from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
 from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
 from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
+from .generation import (
+    AGGREGATED_GENERATION_PROMPT,
+    ATOMIC_GENERATION_PROMPT,
+    COT_GENERATION_PROMPT,
+    MULTI_HOP_GENERATION_PROMPT,
+)
 from .kg_extraction import KG_EXTRACTION_PROMPT
 from .kg_summarization import KG_SUMMARIZATION_PROMPT
-from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
 from .question_generation import QUESTION_GENERATION_PROMPT
 from .search_judgement import SEARCH_JUDGEMENT_PROMPT
 from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
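After this change the four generation prompts are re-exported from a single `generation` sub-package, so call sites keep importing them from `graphgen.templates`. A usage sketch; the language key and `format` fields for the multi-hop prompt match the deleted `traverse_graph.py` above, and assuming the other prompts follow the same dict-by-language layout:

```python
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT

# Keyed by language and filled with pre-numbered entity/relation blocks,
# as in the deleted traverse_graph_for_multi_hop.
prompt = MULTI_HOP_GENERATION_PROMPT["English"].format(
    entities="1. Radium: A radioactive element.",
    relationships="1. Marie Curie -- Radium: discovered",
)
```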
graphgen/templates/community/__init__.py
DELETED
@@ -1,2 +0,0 @@
-from .cot_generation import COT_GENERATION_PROMPT
-from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
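With the `community` package removed, imports of the CoT generation prompt move to the top-level package; note that `COT_TEMPLATE_DESIGN_PROMPT` does not appear among the new `generation` re-exports above, so only the migration for `COT_GENERATION_PROMPT` is shown:

```python
# Before (removed in this commit):
# from graphgen.templates.community import COT_GENERATION_PROMPT
# After:
from graphgen.templates import COT_GENERATION_PROMPT
```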
graphgen/templates/community/cot_generation.py
DELETED
@@ -1,87 +0,0 @@
-TEMPLATE_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\
-CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。
-
--输入格式-
-[Entities:]
-(实体名:实体描述)
-...
-
-[Relationships:]
-(来源实体)-[关系描述]->(目标实体)
-...
-
-[Question and Reasoning Path:]
-(问题)
-(推理路径)
-
--输出要求-
-1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。
-2. 使用中文。
-3. 不要使用有序列表或编号。
-4. 请直接给出答案,不要生成无关信息。
-
--真实数据-
-输入:
-[Entities:]:
-{entities}
-
-[Relationships:]:
-{relationships}
-
-[Question:]:
-{question}
-
-[Reasoning_Template:]:
-{reasoning_template}
-
-输出:
-
-"""
-
-TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \
-produce one Chain-of-Thought (CoT) sample that strictly follows the template \
-and can be directly used for downstream training or inference.
-CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \
-explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \
-only the final answer.
-
--Input Format-
-[Entities:]:
-(ENTITY_NAME: ENTITY_DESCRIPTION)
-...
-
-[Relationships:]:
-(ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET)
-...
-
-[Question and Reasoning Path:]:
-(QUESTION)
-(REASONING_PATH)
-
--Output Requirements-
-1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words.
-2. Use English.
-3. Do not use ordered lists or numbering.
-4. Do not generate extraneous information, just provide the answer.
-
--Real Data-
-Input:
-[Entities:]:
-{entities}
-
-[Relationships:]:
-{relationships}
-
-[Question:]:
-{question}
-
-[Reasoning_Template:]:
-{reasoning_template}
-
-Output:
-"""
-
-COT_GENERATION_PROMPT = {
-    "Chinese": {"TEMPLATE": TEMPLATE_ZH},
-    "English": {"TEMPLATE": TEMPLATE_EN},
-}
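The placeholder names in the deleted templates show how the prompt was filled. A minimal sketch with hypothetical values; the `[language]["TEMPLATE"]` lookup follows the dict defined at the bottom of the deleted file, and the import assumes the prompt's successor under `graphgen/templates/generation` keeps the same shape:

```python
# Assumes COT_GENERATION_PROMPT keeps the layout defined in the
# deleted file above; sample values are hypothetical.
from graphgen.templates import COT_GENERATION_PROMPT

prompt = COT_GENERATION_PROMPT["English"]["TEMPLATE"].format(
    entities="(Radium: A radioactive element.)",
    relationships="(Marie Curie)-[discovered]->(Radium)",
    question="Who discovered radium?",
    reasoning_template="Identify the entity, then follow the discovery relation.",
)
```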