github-actions[bot] committed
Commit 799ac7c · 1 Parent(s): 37f0321

Auto-sync from demo at Wed Oct 15 06:28:02 UTC 2025

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +1 -1
  2. graphgen/bases/__init__.py +2 -0
  3. graphgen/bases/base_generator.py +84 -0
  4. graphgen/bases/base_partitioner.py +76 -0
  5. graphgen/bases/base_storage.py +2 -2
  6. graphgen/bases/datatypes.py +8 -0
  7. graphgen/configs/aggregated_config.yaml +4 -8
  8. graphgen/configs/atomic_config.yaml +2 -9
  9. graphgen/configs/cot_config.yaml +3 -3
  10. graphgen/configs/multi_hop_config.yaml +4 -8
  11. graphgen/graphgen.py +14 -52
  12. graphgen/models/__init__.py +15 -8
  13. graphgen/models/community/__init__.py +0 -0
  14. graphgen/models/community/community_detector.py +0 -95
  15. graphgen/models/evaluate/__init__.py +0 -0
  16. graphgen/models/evaluator/__init__.py +4 -0
  17. graphgen/models/{evaluate → evaluator}/base_evaluator.py +0 -0
  18. graphgen/models/{evaluate → evaluator}/length_evaluator.py +1 -1
  19. graphgen/models/{evaluate → evaluator}/mtld_evaluator.py +1 -1
  20. graphgen/models/{evaluate → evaluator}/reward_evaluator.py +0 -0
  21. graphgen/models/{evaluate → evaluator}/uni_evaluator.py +0 -0
  22. graphgen/models/generator/__init__.py +4 -0
  23. graphgen/models/generator/aggregated_generator.py +127 -0
  24. graphgen/models/generator/atomic_generator.py +52 -0
  25. graphgen/models/generator/cot_generator.py +122 -0
  26. graphgen/models/generator/multi_hop_generator.py +55 -0
  27. graphgen/models/kg_builder/__init__.py +1 -0
  28. graphgen/models/llm/limitter.py +27 -29
  29. graphgen/models/llm/openai_client.py +4 -2
  30. graphgen/models/partitioner/__init__.py +4 -0
  31. graphgen/models/partitioner/bfs_partitioner.py +83 -0
  32. graphgen/models/partitioner/dfs_partitioner.py +80 -0
  33. graphgen/models/partitioner/ece_partitioner.py +163 -0
  34. graphgen/models/partitioner/leiden_partitioner.py +120 -0
  35. graphgen/models/storage/__init__.py +2 -0
  36. graphgen/models/storage/networkx_storage.py +4 -4
  37. graphgen/operators/__init__.py +4 -9
  38. graphgen/operators/build_kg/__init__.py +1 -0
  39. graphgen/operators/build_kg/split_kg.py +0 -382
  40. graphgen/operators/generate/__init__.py +1 -0
  41. graphgen/operators/generate/generate_cot.py +0 -117
  42. graphgen/operators/generate/generate_qas.py +58 -0
  43. graphgen/operators/partition/__init__.py +1 -0
  44. graphgen/operators/partition/partition_kg.py +48 -0
  45. graphgen/operators/partition/pre_tokenize.py +47 -0
  46. graphgen/operators/search/__init__.py +1 -0
  47. graphgen/operators/traverse_graph.py +0 -540
  48. graphgen/templates/__init__.py +6 -3
  49. graphgen/templates/community/__init__.py +0 -2
  50. graphgen/templates/community/cot_generation.py +0 -87
app.py CHANGED
@@ -468,7 +468,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                     label="TPM",
                     minimum=5000,
                     maximum=5000000,
-                    value=50000,
+                    value=100000,
                     step=1000,
                     interactive=True,
                     visible=True,
graphgen/bases/__init__.py CHANGED
@@ -1,5 +1,7 @@
+from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_client import BaseLLMClient
+from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_splitter import BaseSplitter
 from .base_storage import (
graphgen/bases/base_generator.py ADDED
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+
+
+@dataclass
+class BaseGenerator(ABC):
+    """
+    Generate QAs based on given prompts.
+    """
+
+    llm_client: BaseLLMClient
+
+    @staticmethod
+    @abstractmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """Build prompt for LLM based on the given batch"""
+
+    @staticmethod
+    @abstractmethod
+    def parse_response(response: str) -> Any:
+        """Parse the LLM response and return the generated QAs"""
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
+        result.update(qa_pairs)
+        return result
+
+    @staticmethod
+    def format_generation_results(
+        results: list[dict], output_data_format: str
+    ) -> list[dict[str, Any]]:
+        if output_data_format == "Alpaca":
+            results = [
+                {
+                    "instruction": v["question"],
+                    "input": "",
+                    "output": v["answer"],
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "Sharegpt":
+            results = [
+                {
+                    "conversations": [
+                        {"from": "human", "value": v["question"]},
+                        {"from": "gpt", "value": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "ChatML":
+            results = [
+                {
+                    "messages": [
+                        {"role": "user", "content": v["question"]},
+                        {"role": "assistant", "content": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        else:
+            raise ValueError(f"Unknown output data format: {output_data_format}")
+        return results
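For reference, a minimal sketch of how format_generation_results maps the internal QA dict into the supported output formats; the hash key and QA text below are illustrative:

import asyncio  # not needed here; format_generation_results is a plain staticmethod

from graphgen.bases import BaseGenerator

# Illustrative input: one generated QA pair keyed by its content hash.
qa_results = [
    {"a1b2c3": {"question": "What does GraphGen build?", "answer": "A knowledge graph."}}
]

alpaca = BaseGenerator.format_generation_results(qa_results, output_data_format="Alpaca")
# [{"instruction": "What does GraphGen build?", "input": "", "output": "A knowledge graph."}]

chatml = BaseGenerator.format_generation_results(qa_results, output_data_format="ChatML")
# [{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}]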
graphgen/bases/base_partitioner.py ADDED
@@ -0,0 +1,76 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+
+
+@dataclass
+class BasePartitioner(ABC):
+    @abstractmethod
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        **kwargs: Any,
+    ) -> List[Community]:
+        """
+        Graph -> Communities
+        :param g: Graph storage instance
+        :param kwargs: Additional parameters for partitioning
+        :return: List of communities
+        """
+
+    @staticmethod
+    async def community2batch(
+        communities: List[Community], g: BaseGraphStorage
+    ) -> list[
+        tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ]
+    ]:
+        """
+        Convert communities to batches of nodes and edges.
+        :param communities
+        :param g: Graph storage instance
+        :return: List of batches, each batch is a tuple of (nodes, edges)
+        """
+        batches = []
+        for comm in communities:
+            nodes = comm.nodes
+            edges = comm.edges
+            nodes_data = []
+            for node in nodes:
+                node_data = await g.get_node(node)
+                if node_data:
+                    nodes_data.append((node, node_data))
+            edges_data = []
+            for u, v in edges:
+                edge_data = await g.get_edge(u, v)
+                if edge_data:
+                    edges_data.append((u, v, edge_data))
+                else:
+                    edge_data = await g.get_edge(v, u)
+                    if edge_data:
+                        edges_data.append((v, u, edge_data))
+            batches.append((nodes_data, edges_data))
+        return batches
+
+    @staticmethod
+    def _build_adjacency_list(
+        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
+    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
+        """
+        Build adjacency list and edge set from nodes and edges.
+        :param nodes
+        :param edges
+        :return: adjacency list, edge set
+        """
+        adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
+        edge_set: set[tuple[str, str]] = set()
+        for e in edges:
+            adj[e[0]].append(e[1])
+            adj[e[1]].append(e[0])
+            edge_set.add((e[0], e[1]))
+            edge_set.add((e[1], e[0]))
+        return adj, edge_set
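For reference, a minimal sketch of the partition -> community2batch contract, using a toy in-memory stand-in for BaseGraphStorage and the BFSPartitioner added later in this commit; the graph data and helper class are illustrative, not part of the repository:

import asyncio

from graphgen.models import BFSPartitioner


class InMemoryGraph:
    """Toy stand-in for BaseGraphStorage with only the methods the partitioner needs."""

    def __init__(self, nodes, edges):
        self._nodes = dict(nodes)
        self._edges = {(u, v): d for u, v, d in edges}

    async def get_all_nodes(self):
        return list(self._nodes.items())

    async def get_all_edges(self):
        return [(u, v, d) for (u, v), d in self._edges.items()]

    async def get_node(self, node_id):
        return self._nodes.get(node_id)

    async def get_edge(self, u, v):
        return self._edges.get((u, v))


async def demo():
    g = InMemoryGraph(
        nodes=[("A", {"description": "entity A"}), ("B", {"description": "entity B"})],
        edges=[("A", "B", {"description": "A relates to B"})],
    )
    partitioner = BFSPartitioner()
    communities = await partitioner.partition(g, max_units_per_community=3)
    # Each batch is ([(node_id, node_data), ...], [(u, v, edge_data), ...]),
    # which is exactly the shape the generators' build_prompt expects.
    return await partitioner.community2batch(communities, g)


print(asyncio.run(demo()))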
graphgen/bases/base_storage.py CHANGED
@@ -78,7 +78,7 @@ class BaseGraphStorage(StorageNameSpace):
     async def update_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError
 
-    async def get_all_nodes(self) -> Union[list[dict], None]:
+    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
        raise NotImplementedError
 
     async def get_edge(
@@ -91,7 +91,7 @@ class BaseGraphStorage(StorageNameSpace):
     ):
         raise NotImplementedError
 
-    async def get_all_edges(self) -> Union[list[dict], None]:
+    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         raise NotImplementedError
 
     async def get_node_edges(
graphgen/bases/datatypes.py CHANGED
@@ -30,3 +30,11 @@ class Token:
     @property
     def logprob(self) -> float:
         return math.log(self.prob)
+
+
+@dataclass
+class Community:
+    id: Union[int, str]
+    nodes: List[str] = field(default_factory=list)
+    edges: List[tuple] = field(default_factory=list)
+    metadata: dict = field(default_factory=dict)
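For reference, the partitioners added in this commit populate the new Community datatype roughly as follows (values illustrative):

from graphgen.bases.datatypes import Community

comm = Community(
    id=0,
    nodes=["GraphGen", "knowledge graph"],    # node names assigned to this community
    edges=[("GraphGen", "knowledge graph")],  # (u, v) pairs within the community
    metadata={"method": "bfs"},               # optional extra info
)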
graphgen/configs/aggregated_config.yaml CHANGED
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
   method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 5 # maximum depth for graph traversal
-    max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 20 # max nodes and edges per community
+    min_units_per_community: 5 # min nodes and edges per community
+    max_tokens_per_community: 10240 # max tokens per community
+    unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
 generate:
   mode: aggregated # atomic, aggregated, multi_hop, cot
   data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/configs/atomic_config.yaml CHANGED
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
 partition: # graph partition configuration
-  method: ece # ece is a custom partition method based on comprehension loss
+  method: dfs # partition method, support: dfs, bfs, ece, leiden
   method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 3 # maximum depth for graph traversal
-    max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 1 # atomic partition, one node or edge per community
 generate:
   mode: atomic # atomic, aggregated, multi_hop, cot
   data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml CHANGED
@@ -9,11 +9,11 @@ search: # web search configuration
 quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
   enabled: false
 partition: # graph partition configuration
-  method: leiden # leiden is a community detection algorithm
+  method: leiden # leiden is a community detection algorithm used for partitioning
   method_params:
     max_size: 20 # Maximum size of communities
-    use_lcc: false
-    random_seed: 42
+    use_lcc: false # whether to use the largest connected component
+    random_seed: 42 # random seed for partitioning
 generate:
   mode: cot # atomic, aggregated, multi_hop, cot
   data_format: Sharegpt # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml CHANGED
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
 partition: # graph partition configuration
   method: ece # ece is a custom partition method based on comprehension loss
   method_params:
-    bidirectional: true # whether to traverse the graph in both directions
-    edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
-    expand_method: max_width # expand method, support: max_width, max_depth
-    isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
-    max_depth: 1 # maximum depth for graph traversal
-    max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
-    max_tokens: 256 # restricts input length (if expand_method="max_tokens")
-    loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
+    max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
+    min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
+    max_tokens_per_community: 10240 # max tokens per community
+    unit_sampling: random # edge sampling strategy, support: random, max_loss, min_loss
 generate:
   mode: multi_hop # strategy for generating multi-hop QA pairs
   data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/graphgen.py CHANGED
@@ -18,21 +18,14 @@ from graphgen.models import (
 from graphgen.operators import (
     build_kg,
     chunk_documents,
-    generate_cot,
+    generate_qas,
     judge_statement,
+    partition_kg,
     quiz,
     read_files,
     search_all,
-    traverse_graph_for_aggregated,
-    traverse_graph_for_atomic,
-    traverse_graph_for_multi_hop,
-)
-from graphgen.utils import (
-    async_to_sync_method,
-    compute_content_hash,
-    format_generation_results,
-    logger,
 )
+from graphgen.utils import async_to_sync_method, compute_content_hash, logger
 
 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
@@ -238,51 +231,20 @@ class GraphGen:
     @async_to_sync_method
     async def generate(self, partition_config: Dict, generate_config: Dict):
         # Step 1: partition the graph
-        # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
-        mode = generate_config["mode"]
-        if mode == "atomic":
-            results = await traverse_graph_for_atomic(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "multi_hop":
-            results = await traverse_graph_for_multi_hop(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "aggregated":
-            results = await traverse_graph_for_aggregated(
-                self.synthesizer_llm_client,
-                self.tokenizer_instance,
-                self.graph_storage,
-                partition_config["method_params"],
-                self.text_chunks_storage,
-                self.progress_bar,
-            )
-        elif mode == "cot":
-            results = await generate_cot(
-                self.graph_storage,
-                self.synthesizer_llm_client,
-                method_params=partition_config["method_params"],
-            )
-        else:
-            raise ValueError(f"Unknown generation mode: {mode}")
-        # Step 2: generate QA pairs
-        # TODO
+        batches = await partition_kg(
+            self.graph_storage, self.tokenizer_instance, partition_config
+        )
 
-        # Step 3: format
-        results = format_generation_results(
-            results, output_data_format=generate_config["data_format"]
+        # Step 2: generate QA pairs
+        results = await generate_qas(
+            self.synthesizer_llm_client, batches, generate_config
         )
 
+        if not results:
+            logger.warning("No QA pairs generated")
+            return
+
+        # Step 3: store the generated QA pairs
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
 
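With this change, generate() becomes a fixed three-step pipeline: partition_kg builds (nodes, edges) batches from the graph, generate_qas turns each batch into QA pairs, and the results are persisted. A minimal caller-side sketch; the config keys come from the YAML files above, while the GraphGen constructor kwargs are not shown in this diff and are therefore assumptions:

from graphgen.graphgen import GraphGen

gg = GraphGen(working_dir="cache")  # hypothetical construction; real kwargs come from GraphGen.__init__

partition_config = {
    "method": "dfs",                                   # dfs, bfs, ece, leiden
    "method_params": {"max_units_per_community": 1},   # atomic: one node or edge per community
}
generate_config = {
    "mode": "atomic",         # atomic, aggregated, multi_hop, cot
    "data_format": "Alpaca",  # Alpaca, Sharegpt, ChatML
}

# generate() is wrapped by @async_to_sync_method, so it can be called synchronously.
gg.generate(partition_config, generate_config)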
graphgen/models/__init__.py CHANGED
@@ -1,17 +1,24 @@
-from .community.community_detector import CommunityDetector
-from .evaluate.length_evaluator import LengthEvaluator
-from .evaluate.mtld_evaluator import MTLDEvaluator
-from .evaluate.reward_evaluator import RewardEvaluator
-from .evaluate.uni_evaluator import UniEvaluator
-from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
+from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
+from .generator import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
+from .kg_builder import LightRAGKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
+from .partitioner import (
+    BFSPartitioner,
+    DFSPartitioner,
+    ECEPartitioner,
+    LeidenPartitioner,
+)
 from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
 from .search.db.uniprot_search import UniProtSearch
 from .search.kg.wiki_search import WikiSearch
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage.json_storage import JsonKVStorage, JsonListStorage
-from .storage.networkx_storage import NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
 from .tokenizer import Tokenizer
graphgen/models/community/__init__.py DELETED
File without changes
graphgen/models/community/community_detector.py DELETED
@@ -1,95 +0,0 @@
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-from graphgen.models.storage.networkx_storage import NetworkXStorage
-
-
-@dataclass
-class CommunityDetector:
-    """Class for community detection algorithms."""
-
-    graph_storage: NetworkXStorage = None
-    method: str = "leiden"
-    method_params: Dict[str, Any] = None
-
-    async def detect_communities(self) -> Dict[str, int]:
-        if self.method == "leiden":
-            return await self._leiden_communities(**self.method_params or {})
-        raise ValueError(f"Unknown community detection method: {self.method}")
-
-    async def get_graph(self):
-        return await self.graph_storage.get_graph()
-
-    async def _leiden_communities(
-        self, max_size: int = None, **kwargs
-    ) -> Dict[str, int]:
-        """
-        Detect communities using the Leiden algorithm.
-        If max_size is given, any community larger than max_size will be split
-        into smaller sub-communities each having at most max_size nodes.
-        """
-        import igraph as ig
-        import networkx as nx
-        from leidenalg import ModularityVertexPartition, find_partition
-
-        graph = await self.get_graph()
-        graph.remove_nodes_from(list(nx.isolates(graph)))
-
-        ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
-
-        random_seed = kwargs.get("random_seed", 42)
-        use_lcc = kwargs.get("use_lcc", False)
-
-        communities: Dict[str, int] = {}
-        if use_lcc:
-            lcc = ig_graph.components().giant()
-            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
-            for part, cluster in enumerate(partition):
-                for v in cluster:
-                    communities[lcc.vs[v]["name"]] = part
-        else:
-            offset = 0
-            for component in ig_graph.components():
-                subgraph = ig_graph.induced_subgraph(component)
-                partition = find_partition(
-                    subgraph, ModularityVertexPartition, seed=random_seed
-                )
-                for part, cluster in enumerate(partition):
-                    for v in cluster:
-                        original_node = subgraph.vs[v]["name"]
-                        communities[original_node] = part + offset
-                offset += len(partition)
-
-        # split large communities if max_size is specified
-        if max_size is None or max_size <= 0:
-            return communities
-
-        return await self._split_communities(communities, max_size)
-
-    @staticmethod
-    async def _split_communities(
-        communities: Dict[str, int], max_size: int
-    ) -> Dict[str, int]:
-        """
-        Split communities larger than max_size into smaller sub-communities.
-        """
-        cid2nodes: Dict[int, List[str]] = defaultdict(list)
-        for node, cid in communities.items():
-            cid2nodes[cid].append(node)
-
-        new_communities: Dict[str, int] = {}
-        new_cid = 0
-        for cid, nodes in cid2nodes.items():
-            if len(nodes) <= max_size:
-                for n in nodes:
-                    new_communities[n] = new_cid
-                new_cid += 1
-            else:
-                for start in range(0, len(nodes), max_size):
-                    sub = nodes[start : start + max_size]
-                    for n in sub:
-                        new_communities[n] = new_cid
-                    new_cid += 1
-
-        return new_communities
graphgen/models/evaluate/__init__.py DELETED
File without changes
graphgen/models/evaluator/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .length_evaluator import LengthEvaluator
+from .mtld_evaluator import MTLDEvaluator
+from .reward_evaluator import RewardEvaluator
+from .uni_evaluator import UniEvaluator
graphgen/models/{evaluate → evaluator}/base_evaluator.py RENAMED
File without changes
graphgen/models/{evaluate → evaluator}/length_evaluator.py RENAMED
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 
 from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.models.tokenizer import Tokenizer
 from graphgen.utils import create_event_loop
 
graphgen/models/{evaluate → evaluator}/mtld_evaluator.py RENAMED
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
 from typing import Set
 
 from graphgen.bases.datatypes import QAPair
-from graphgen.models.evaluate.base_evaluator import BaseEvaluator
+from graphgen.models.evaluator.base_evaluator import BaseEvaluator
 from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language
 
 nltk_helper = NLTKHelper()
graphgen/models/{evaluate → evaluator}/reward_evaluator.py RENAMED
File without changes
graphgen/models/{evaluate → evaluator}/uni_evaluator.py RENAMED
File without changes
graphgen/models/generator/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .aggregated_generator import AggregatedGenerator
+from .atomic_generator import AtomicGenerator
+from .cot_generator import CoTGenerator
+from .multi_hop_generator import MultiHopGenerator
graphgen/models/generator/aggregated_generator.py ADDED
@@ -0,0 +1,127 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import AGGREGATED_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class AggregatedGenerator(BaseGenerator):
+    """
+    Aggregated Generator follows a TWO-STEP process:
+    1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning.
+       The rephrased text is considered as answer to be used in the next step.
+    2. question generation: Generate relevant questions based on the rephrased text.
+    """
+
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """
+        Build prompts for REPHRASE.
+        :param batch
+        :return:
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relations_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relations_str)
+
+        # TODO: configure add_context
+        # if add_context:
+        #     original_ids = [
+        #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
+        #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
+        #     original_ids = list(set(original_ids))
+        #     original_text = await text_chunks_storage.get_by_ids(original_ids)
+        #     original_text = "\n".join(
+        #         [
+        #             f"{index + 1}. {text['content']}"
+        #             for index, text in enumerate(original_text)
+        #         ]
+        #     )
+        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
+            language=language, entities=entities_str, relationships=relations_str
+        )
+        return prompt
+
+    @staticmethod
+    def parse_rephrased_text(response: str) -> str:
+        """
+        Parse the rephrased text from the response.
+        :param response:
+        :return: rephrased text
+        """
+        if "Rephrased Text:" in response:
+            rephrased_text = response.split("Rephrased Text:")[1].strip()
+        elif "重述文本:" in response:
+            rephrased_text = response.split("重述文本:")[1].strip()
+        else:
+            rephrased_text = response.strip()
+        return rephrased_text.strip('"')
+
+    @staticmethod
+    def _build_prompt_for_question_generation(answer: str) -> str:
+        """
+        Build prompts for QUESTION GENERATION.
+        :param answer:
+        :return:
+        """
+        language = detect_main_language(answer)
+        prompt = AGGREGATED_GENERATION_PROMPT[language]["QUESTION_GENERATION"].format(
+            answer=answer
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if response.startswith("Question:"):
+            question = response[len("Question:") :].strip()
+        elif response.startswith("问题:"):
+            question = response[len("问题:") :].strip()
+        else:
+            question = response.strip()
+        return {
+            "question": question,
+        }
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        rephrasing_prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(rephrasing_prompt)
+        context = self.parse_rephrased_text(response)
+        question_generation_prompt = self._build_prompt_for_question_generation(context)
+        response = await self.llm_client.generate_answer(question_generation_prompt)
+        question = self.parse_response(response)["question"]
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", context)
+        qa_pairs = {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": context,
+            }
+        }
+        result.update(qa_pairs)
+        return result
graphgen/models/generator/atomic_generator.py ADDED
@@ -0,0 +1,52 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import ATOMIC_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class AtomicGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        context = ""
+        for node in nodes:
+            context += f"- {node[0]}: {node[1]['description']}\n"
+        for edge in edges:
+            context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n"
+        language = detect_main_language(context)
+
+        prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context)
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        """
+        AtomicGenerator normally generates one QA pair per response.
+        So we just need to parse one QA pair from the response.
+        :param response:
+        :return:
+        """
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return {}
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
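A minimal end-to-end sketch of the generator contract, using a canned stand-in client instead of a real BaseLLMClient and assuming the English prompt template key exists in graphgen.templates; all sample data below is illustrative:

import asyncio

from graphgen.models import AtomicGenerator


class FakeLLMClient:
    """Stand-in for BaseLLMClient; always returns one canned QA response."""

    async def generate_answer(self, prompt: str) -> str:
        return 'Question: "What does GraphGen partition?" Answer: "A knowledge graph."'


batch = (
    [("GraphGen", {"description": "A framework for synthetic QA generation."})],
    [("GraphGen", "knowledge graph", {"description": "GraphGen partitions a knowledge graph."})],
)

qa = asyncio.run(AtomicGenerator(llm_client=FakeLLMClient()).generate(batch))
# qa maps a content hash of the question to {"question": ..., "answer": ...}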
graphgen/models/generator/cot_generator.py ADDED
@@ -0,0 +1,122 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import COT_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class CoTGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """
+        Build prompts for COT Template Design.
+        :param batch:
+        :return:
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = COT_GENERATION_PROMPT[language]["COT_TEMPLATE_DESIGN"].format(
+            entities=entities_str, relationships=relationships_str
+        )
+        return prompt
+
+    @staticmethod
+    def build_prompt_for_cot_generation(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]],
+        question: str,
+        reasoning_path: str,
+    ) -> str:
+        """
+        Build prompts for COT Generation.
+        """
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = COT_GENERATION_PROMPT[language]["COT_GENERATION"].format(
+            entities=entities_str,
+            relationships=relationships_str,
+            question=question,
+            reasoning_template=reasoning_path,
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if "Question:" in response and "Reasoning-Path Design:" in response:
+            question = (
+                response.split("Question:")[1]
+                .split("Reasoning-Path Design:")[0]
+                .strip()
+            )
+            reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
+        elif "问题:" in response and "推理路径设计:" in response:
+            question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
+            reasoning_path = response.split("推理路径设计:")[1].strip()
+        else:
+            logger.warning("Failed to parse CoT template: %s", response)
+            return {}
+
+        question = question.strip('"')
+        reasoning_path = reasoning_path.strip('"')
+        logger.info("CoT Question: %s", question)
+        logger.info("CoT Reasoning Path: %s", reasoning_path)
+        return {
+            "question": question,
+            "reasoning_path": reasoning_path,
+        }
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        response = self.parse_response(response)
+        question, reasoning_path = response["question"], response["reasoning_path"]
+        prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
+        cot_answer = await self.llm_client.generate_answer(prompt)
+        logger.info("CoT Answer: %s", cot_answer)
+        qa_pairs = {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": cot_answer,
+                "reasoning_path": reasoning_path,
+            }
+        }
+        result.update(qa_pairs)
+        return result
graphgen/models/generator/multi_hop_generator.py ADDED
@@ -0,0 +1,55 @@
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases import BaseGenerator
+from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
+from graphgen.utils import compute_content_hash, detect_main_language, logger
+
+
+@dataclass
+class MultiHopGenerator(BaseGenerator):
+    @staticmethod
+    def build_prompt(
+        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        entities_str = "\n".join(
+            [
+                f"{index + 1}. {node[0]}: {node[1]['description']}"
+                for index, node in enumerate(nodes)
+            ]
+        )
+
+        relationships_str = "\n".join(
+            [
+                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
+                for index, edge in enumerate(edges)
+            ]
+        )
+        language = detect_main_language(entities_str + relationships_str)
+        prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
+            entities=entities_str, relationships=relationships_str
+        )
+        return prompt
+
+    @staticmethod
+    def parse_response(response: str) -> dict:
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return {}
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
graphgen/models/kg_builder/__init__.py CHANGED
@@ -0,0 +1 @@
+from .light_rag_kg_builder import LightRAGKGBuilder
graphgen/models/llm/limitter.py CHANGED
@@ -1,17 +1,17 @@
+import asyncio
 import time
 from datetime import datetime, timedelta
-import asyncio
 
 from graphgen.utils import logger
 
 
 class RPM:
-
     def __init__(self, rpm: int = 1000):
         self.rpm = rpm
-        self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
+        self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}
 
-    def get_minute_slot(self):
+    @staticmethod
+    def get_minute_slot():
         current_time = time.time()
         dt_object = datetime.fromtimestamp(current_time)
         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
@@ -22,37 +22,35 @@ class RPM:
         dt_object = datetime.fromtimestamp(current)
         minute_slot = self.get_minute_slot()
 
-        if self.record['rpm_slot'] == minute_slot:
+        if self.record["rpm_slot"] == minute_slot:
             # check RPM exceed
-            if self.record['counter'] >= self.rpm:
+            if self.record["counter"] >= self.rpm:
                 # wait until next minute
-                next_minute = dt_object.replace(
-                    second=0, microsecond=0) + timedelta(minutes=1)
+                next_minute = dt_object.replace(second=0, microsecond=0) + timedelta(
+                    minutes=1
+                )
                 _next = next_minute.timestamp()
                 sleep_time = abs(_next - current)
                 if not silent:
-                    logger.info('RPM sleep %s', sleep_time)
+                    logger.info("RPM sleep %s", sleep_time)
                 await asyncio.sleep(sleep_time)
 
-                self.record = {
-                    'rpm_slot': self.get_minute_slot(),
-                    'counter': 0
-                }
+                self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}
         else:
-            self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0}
-        self.record['counter'] += 1
+            self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0}
+        self.record["counter"] += 1
 
         if not silent:
             logger.debug(self.record)
 
 
 class TPM:
-
     def __init__(self, tpm: int = 20000):
         self.tpm = tpm
-        self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0}
+        self.record = {"tpm_slot": self.get_minute_slot(), "counter": 0}
 
-    def get_minute_slot(self):
+    @staticmethod
+    def get_minute_slot():
         current_time = time.time()
         dt_object = datetime.fromtimestamp(current_time)
         total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute
@@ -64,25 +62,25 @@ class TPM:
         minute_slot = self.get_minute_slot()
 
         # get next slot, skip
-        if self.record['tpm_slot'] != minute_slot:
-            self.record = {'tpm_slot': minute_slot, 'counter': token_count}
+        if self.record["tpm_slot"] != minute_slot:
+            self.record = {"tpm_slot": minute_slot, "counter": token_count}
             return
 
         # check RPM exceed
-        self.record['counter'] += token_count
-        if self.record['counter'] > self.tpm:
+        old_counter = self.record["counter"]
+        self.record["counter"] += token_count
+        if self.record["counter"] > self.tpm:
+            logger.info("Current TPM: %s, limit: %s", old_counter, self.tpm)
            # wait until next minute
-            next_minute = dt_object.replace(
-                second=0, microsecond=0) + timedelta(minutes=1)
+            next_minute = dt_object.replace(second=0, microsecond=0) + timedelta(
+                minutes=1
+            )
             _next = next_minute.timestamp()
             sleep_time = abs(_next - current)
-            logger.info('TPM sleep %s', sleep_time)
+            logger.warning("TPM limit exceeded, wait %s seconds", sleep_time)
             await asyncio.sleep(sleep_time)
 
-            self.record = {
-                'tpm_slot': self.get_minute_slot(),
-                'counter': token_count
-            }
+            self.record = {"tpm_slot": self.get_minute_slot(), "counter": token_count}
 
         if not silent:
             logger.debug(self.record)
graphgen/models/llm/openai_client.py CHANGED
@@ -39,6 +39,8 @@ class OpenAIClient(BaseLLMClient):
         seed: Optional[int] = None,
         topk_per_token: int = 5,  # number of topk tokens to generate for each token
         request_limit: bool = False,
+        rpm: Optional[RPM] = None,
+        tpm: Optional[TPM] = None,
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
@@ -51,8 +53,8 @@ class OpenAIClient(BaseLLMClient):
 
         self.token_usage: list = []
         self.request_limit = request_limit
-        self.rpm = RPM(rpm=1000)
-        self.tpm = TPM(tpm=50000)
+        self.rpm = rpm or RPM()
+        self.tpm = tpm or TPM()
 
         self.__post_init__()
 
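With this change the client no longer hard-codes RPM(rpm=1000) and TPM(tpm=50000); callers can inject their own limiters. A minimal sketch of the injection; the remaining OpenAIClient kwargs come from BaseLLMClient and are not shown in this diff, so they are omitted here:

from graphgen.models import OpenAIClient
from graphgen.models.llm.limitter import RPM, TPM

# Match the limiter budgets to the provider's quota; the defaults are RPM() and TPM().
client = OpenAIClient(
    request_limit=True,   # enable rate limiting
    rpm=RPM(rpm=500),     # requests per minute
    tpm=TPM(tpm=100000),  # tokens per minute, matching the new demo TPM default in app.py
)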
graphgen/models/partitioner/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .bfs_partitioner import BFSPartitioner
+from .dfs_partitioner import DFSPartitioner
+from .ece_partitioner import ECEPartitioner
+from .leiden_partitioner import LeidenPartitioner
graphgen/models/partitioner/bfs_partitioner.py ADDED
@@ -0,0 +1,83 @@
+import random
+from collections import deque
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class BFSPartitioner(BasePartitioner):
+    """
+    BFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Expand the community using BFS until the max unit size is reached.
+    (A unit is a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [(NODE_UNIT, n[0]) for n in nodes] + [
+            (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == NODE_UNIT and seed in used_n) or (
+                kind == EDGE_UNIT and seed in used_e
+            ):
+                continue
+
+            comm_n: List[str] = []
+            comm_e: List[tuple[str, str]] = []
+            queue: deque[tuple[str, Any]] = deque([(kind, seed)])
+            cnt = 0
+
+            while queue and cnt < max_units_per_community:
+                k, it = queue.popleft()
+                if k == NODE_UNIT:
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            queue.append((EDGE_UNIT, e_key))
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+
+                    u, v = it
+                    comm_e.append((u, v))
+                    cnt += 1
+                    # push nodes that are not visited
+                    for n in it:
+                        if n not in used_n:
+                            queue.append((NODE_UNIT, n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
graphgen/models/partitioner/dfs_partitioner.py ADDED
@@ -0,0 +1,80 @@
+import random
+from dataclasses import dataclass
+from typing import Any, List
+
+from graphgen.bases import BaseGraphStorage, BasePartitioner
+from graphgen.bases.datatypes import Community
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class DFSPartitioner(BasePartitioner):
+    """
+    DFS partitioner that partitions the graph into communities of a fixed size.
+    1. Randomly choose a unit.
+    2. Random walk using DFS until the community reaches the max unit size.
+    (In GraphGen, a unit is defined as a node or an edge.)
+    """
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 1,
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes = await g.get_all_nodes()
+        edges = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+
+        used_n: set[str] = set()
+        used_e: set[frozenset[str]] = set()
+        communities: List[Community] = []
+
+        units = [(NODE_UNIT, n[0]) for n in nodes] + [
+            (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
+        ]
+        random.shuffle(units)
+
+        for kind, seed in units:
+            if (kind == NODE_UNIT and seed in used_n) or (
+                kind == EDGE_UNIT and seed in used_e
+            ):
+                continue
+
+            comm_n, comm_e = [], []
+            stack = [(kind, seed)]
+            cnt = 0
+
+            while stack and cnt < max_units_per_community:
+                k, it = stack.pop()
+                if k == NODE_UNIT:
+                    if it in used_n:
+                        continue
+                    used_n.add(it)
+                    comm_n.append(it)
+                    cnt += 1
+                    for nei in adj[it]:
+                        e_key = frozenset((it, nei))
+                        if e_key not in used_e:
+                            stack.append((EDGE_UNIT, e_key))
+                            break
+                else:
+                    if it in used_e:
+                        continue
+                    used_e.add(it)
+                    comm_e.append(tuple(it))
+                    cnt += 1
+                    # push neighboring nodes
+                    for n in it:
+                        if n not in used_n:
+                            stack.append((NODE_UNIT, n))
+
+            if comm_n or comm_e:
+                communities.append(
+                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
+                )
+
+        return communities
graphgen/models/partitioner/ece_partitioner.py ADDED
@@ -0,0 +1,163 @@
+import asyncio
+import random
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple
+
+from tqdm.asyncio import tqdm as tqdm_async
+
+from graphgen.bases import BaseGraphStorage
+from graphgen.bases.datatypes import Community
+from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner
+
+NODE_UNIT: str = "n"
+EDGE_UNIT: str = "e"
+
+
+@dataclass
+class ECEPartitioner(BFSPartitioner):
+    """
+    ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
+    We calculate ECE for edges in the KG (represented as 'comprehension loss')
+    and group edges with similar ECE values into the same community.
+    1. Select a sampling strategy.
+    2. Choose a unit based on the sampling strategy.
+    3. Expand the community using BFS.
+    4. When expanding, prefer to add units according to the sampling strategy.
+    5. Stop when the max unit size is reached or the max input length is reached.
+    (A unit is a node or an edge.)
+    """
+
+    @staticmethod
+    def _sort_units(units: list, edge_sampling: str) -> list:
+        """
+        Sort units with edge sampling strategy
+
+        :param units: total units
+        :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
+        :return: sorted units
+        """
+        if edge_sampling == "random":
+            random.shuffle(units)
+        elif edge_sampling == "min_loss":
+            units = sorted(
+                units,
+                key=lambda x: x[-1]["loss"],
+            )
+        elif edge_sampling == "max_loss":
+            units = sorted(
+                units,
+                key=lambda x: x[-1]["loss"],
+                reverse=True,
+            )
+        else:
+            raise ValueError(f"Invalid edge sampling: {edge_sampling}")
+        return units
+
+    async def partition(
+        self,
+        g: BaseGraphStorage,
+        max_units_per_community: int = 10,
+        min_units_per_community: int = 1,
+        max_tokens_per_community: int = 10240,
+        unit_sampling: str = "random",
+        **kwargs: Any,
+    ) -> List[Community]:
+        nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
+        edges: List[Tuple[str, str, dict]] = await g.get_all_edges()
+
+        adj, _ = self._build_adjacency_list(nodes, edges)
+        node_dict = dict(nodes)
+        edge_dict = {frozenset((u, v)): d for u, v, d in edges}
+
+        all_units: List[Tuple[str, Any, dict]] = [
+            (NODE_UNIT, nid, d) for nid, d in nodes
+        ] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges]
+
+        used_n: Set[str] = set()
+        used_e: Set[frozenset[str]] = set()
+        communities: List = []
+
+        all_units = self._sort_units(all_units, unit_sampling)
+
+        async def _grow_community(
+            seed_unit: Tuple[str, Any, dict]
+        ) -> Optional[Community]:
+            nonlocal used_n, used_e
+
+            community_nodes: Dict[str, dict] = {}
+            community_edges: Dict[frozenset[str], dict] = {}
+            queue: asyncio.Queue = asyncio.Queue()
+            token_sum = 0
+
+            async def _add_unit(u):
+                nonlocal token_sum
+                t, i, d = u
+                if t == NODE_UNIT:  # node
+                    if i in used_n or i in community_nodes:
+                        return False
+                    community_nodes[i] = d
+                    used_n.add(i)
+                else:  # edge
+                    if i in used_e or i in community_edges:
+                        return False
+                    community_edges[i] = d
+                    used_e.add(i)
+                token_sum += d.get("length", 0)
+                return True
+
+            await _add_unit(seed_unit)
+            await queue.put(seed_unit)
+
+            # BFS
+            while not queue.empty():
+                if (
+                    len(community_nodes) + len(community_edges)
+                    >= max_units_per_community
+                    or token_sum >= max_tokens_per_community
+                ):
+                    break
+
+                cur_type, cur_id, _ = await queue.get()
+
+                neighbors: List[Tuple[str, Any, dict]] = []
+                if cur_type == NODE_UNIT:
+                    for nb_id in adj.get(cur_id, []):
+                        e_key = frozenset((cur_id, nb_id))
+                        if e_key not in used_e and e_key not in community_edges:
+                            neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key]))
+                else:
+                    for n_id in cur_id:
+                        if n_id not in used_n and n_id not in community_nodes:
+                            neighbors.append((NODE_UNIT, n_id, node_dict[n_id]))
+
+                neighbors = self._sort_units(neighbors, unit_sampling)
+                for nb in neighbors:
+                    if (
+                        len(community_nodes) + len(community_edges)
+                        >= max_units_per_community
+                        or token_sum >= max_tokens_per_community
+                    ):
+                        break
+                    if await _add_unit(nb):
+                        await queue.put(nb)
+
+            if len(community_nodes) + len(community_edges) < min_units_per_community:
+                return None
+
+            return Community(
+                id=len(communities),
+                nodes=list(community_nodes.keys()),
+                edges=[(u, v) for (u, v), _ in community_edges.items()],
+            )
+
+        async for unit in tqdm_async(all_units, desc="ECE partition"):
+            utype, uid, _ = unit
+            if (utype == NODE_UNIT and uid in used_n) or (
+                utype == EDGE_UNIT and uid in used_e
+            ):
+                continue
+            comm = await _grow_community(unit)
+            if comm is not None:
+                communities.append(comm)
+
+        return communities
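One contract worth noting: when unit_sampling is min_loss or max_loss, every node and edge dict must carry a "loss" value (the comprehension loss from the quiz/judge step), and the token budget reads a "length" value, presumably filled by the new pre_tokenize operator. A tiny illustration of the sorting helper; the data is made up and _sort_units is an internal method shown here only to make the ordering concrete:

from graphgen.models import ECEPartitioner

units = [
    ("e", frozenset(("A", "B")), {"loss": 0.9, "length": 40}),
    ("n", "A", {"loss": 0.2, "length": 12}),
]
ordered = ECEPartitioner._sort_units(units, "max_loss")
# The high-loss A--B edge sorts first, so it seeds and expands a community before node A.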
graphgen/models/partitioner/leiden_partitioner.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Set, Tuple
4
+
5
+ import igraph as ig
6
+ from leidenalg import ModularityVertexPartition, find_partition
7
+
8
+ from graphgen.bases import BaseGraphStorage, BasePartitioner
9
+ from graphgen.bases.datatypes import Community
10
+
11
+
12
+ @dataclass
13
+ class LeidenPartitioner(BasePartitioner):
14
+ """
15
+ Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
16
+ """
17
+
18
+ async def partition(
19
+ self,
20
+ g: BaseGraphStorage,
21
+ max_size: int = 20,
22
+ use_lcc: bool = False,
23
+ random_seed: int = 42,
24
+ **kwargs: Any,
25
+ ) -> List[Community]:
26
+ """
27
+ Leiden Partition follows these steps:
28
+ 1. export the graph from graph storage
29
+ 2. use the leiden algorithm to detect communities, get {node: community_id}
30
+ 3. split large communities if max_size is given
31
+ 4. convert {node: community_id} to List[Community]
32
+ :param g: graph storage containing the knowledge graph to partition
33
+ :param max_size: maximum size of each community; if None or <=0, no limit
34
+ :param use_lcc: whether to use the largest connected component only
35
+ :param random_seed: random seed for the Leiden algorithm
36
+ :param kwargs: other parameters for the leiden algorithm
37
+ :return: list of detected communities
38
+ """
39
+ nodes = await g.get_all_nodes() # List[Tuple[str, dict]]
40
+ edges = await g.get_all_edges() # List[Tuple[str, str, dict]]
41
+
42
+ node2cid: Dict[str, int] = await self._run_leiden(
43
+ nodes, edges, use_lcc, random_seed
44
+ )
45
+
46
+ if max_size is not None and max_size > 0:
47
+ node2cid = await self._split_communities(node2cid, max_size)
48
+
49
+ cid2nodes: Dict[int, List[str]] = defaultdict(list)
50
+ for n, cid in node2cid.items():
51
+ cid2nodes[cid].append(n)
52
+
53
+ communities: List[Community] = []
54
+ for cid, comm_nodes in cid2nodes.items():
55
+ node_set: Set[str] = set(comm_nodes)
56
+ comm_edges: List[Tuple[str, str]] = [
57
+ (u, v) for u, v, _ in edges if u in node_set and v in node_set
58
+ ]
59
+ communities.append(Community(id=cid, nodes=comm_nodes, edges=comm_edges))
60
+ return communities
61
+
62
+ @staticmethod
63
+ async def _run_leiden(
64
+ nodes: List[Tuple[str, dict]],
65
+ edges: List[Tuple[str, str, dict]],
66
+ use_lcc: bool = False,
67
+ random_seed: int = 42,
68
+ ) -> Dict[str, int]:
69
+ # build igraph
70
+ ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)
71
+
72
+ # remove isolated nodes
73
+ ig_graph.delete_vertices(ig_graph.vs.select(_degree_eq=0))
74
+
75
+ node2cid: Dict[str, int] = {}
76
+ if use_lcc:
77
+ lcc = ig_graph.components().giant()
78
+ partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
79
+ for part_id, cluster in enumerate(partition):
80
+ for v in cluster:
81
+ node2cid[lcc.vs[v]["name"]] = part_id
82
+ else:
83
+ offset = 0
84
+ for component in ig_graph.components():
85
+ subgraph = ig_graph.induced_subgraph(component)
86
+ partition = find_partition(
87
+ subgraph, ModularityVertexPartition, seed=random_seed
88
+ )
89
+ for part_id, cluster in enumerate(partition):
90
+ for v in cluster:
91
+ original_node = subgraph.vs[v]["name"]
92
+ node2cid[original_node] = part_id + offset
93
+ offset += len(partition)
94
+ return node2cid
95
+
96
+ @staticmethod
97
+ async def _split_communities(
98
+ node2cid: Dict[str, int], max_size: int
99
+ ) -> Dict[str, int]:
100
+ """
101
+ Split communities larger than max_size into smaller sub-communities.
102
+ """
103
+ cid2nodes: Dict[int, List[str]] = defaultdict(list)
104
+ for n, cid in node2cid.items():
105
+ cid2nodes[cid].append(n)
106
+
107
+ new_mapping: Dict[str, int] = {}
108
+ new_cid = 0
109
+ for nodes in cid2nodes.values():
110
+ if len(nodes) <= max_size:
111
+ for n in nodes:
112
+ new_mapping[n] = new_cid
113
+ new_cid += 1
114
+ else:
115
+ for start in range(0, len(nodes), max_size):
116
+ chunk = nodes[start : start + max_size]
117
+ for n in chunk:
118
+ new_mapping[n] = new_cid
119
+ new_cid += 1
120
+ return new_mapping
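For reference, a minimal standalone run of the same igraph/leidenalg calls used in _run_leiden above (the toy edge list is purely illustrative):

import igraph as ig
from leidenalg import ModularityVertexPartition, find_partition

# Two dense triangles joined by a single bridge edge; the clusters should be obvious.
edges = [("a", "b"), ("b", "c"), ("a", "c"),
         ("x", "y"), ("y", "z"), ("x", "z"),
         ("c", "x")]
g = ig.Graph.TupleList(edges, directed=False)

partition = find_partition(g, ModularityVertexPartition, seed=42)
node2cid = {
    g.vs[v]["name"]: cid
    for cid, cluster in enumerate(partition)
    for v in cluster
}
print(node2cid)  # e.g. {'a': 0, 'b': 0, 'c': 0, 'x': 1, 'y': 1, 'z': 1}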
graphgen/models/storage/__init__.py CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .json_storage import JsonKVStorage, JsonListStorage
2
+ from .networkx_storage import NetworkXStorage
graphgen/models/storage/networkx_storage.py CHANGED
@@ -102,8 +102,8 @@ class NetworkXStorage(BaseGraphStorage):
102
  async def get_node(self, node_id: str) -> Union[dict, None]:
103
  return self._graph.nodes.get(node_id)
104
 
105
- async def get_all_nodes(self) -> Union[list[dict], None]:
106
- return self._graph.nodes(data=True)
107
 
108
  async def node_degree(self, node_id: str) -> int:
109
  return self._graph.degree(node_id)
@@ -116,8 +116,8 @@ class NetworkXStorage(BaseGraphStorage):
116
  ) -> Union[dict, None]:
117
  return self._graph.edges.get((source_node_id, target_node_id))
118
 
119
- async def get_all_edges(self) -> Union[list[dict], None]:
120
- return self._graph.edges(data=True)
121
 
122
  async def get_node_edges(
123
  self, source_node_id: str
 
102
  async def get_node(self, node_id: str) -> Union[dict, None]:
103
  return self._graph.nodes.get(node_id)
104
 
105
+ async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
106
+ return list(self._graph.nodes(data=True))
107
 
108
  async def node_degree(self, node_id: str) -> int:
109
  return self._graph.degree(node_id)
 
116
  ) -> Union[dict, None]:
117
  return self._graph.edges.get((source_node_id, target_node_id))
118
 
119
+ async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
120
+ return list(self._graph.edges(data=True))
121
 
122
  async def get_node_edges(
123
  self, source_node_id: str
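The list(...) wrapping in the new return statements matters because G.nodes(data=True) and G.edges(data=True) are live views. A small illustration with plain networkx (toy graph, illustrative values):

import networkx as nx

g = nx.Graph()
g.add_edge("a", "b", description="a-b")

view = g.edges(data=True)            # live view: reflects later mutations, not indexable
snapshot = list(g.edges(data=True))  # plain list[tuple[str, str, dict]], frozen at call time

g.add_edge("b", "c", description="b-c")
print(len(view), len(snapshot))      # 2 1 -> the snapshot stays stable for the partitioners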
graphgen/operators/__init__.py CHANGED
@@ -1,13 +1,8 @@
1
- from graphgen.operators.build_kg.build_kg import build_kg
2
- from graphgen.operators.generate.generate_cot import generate_cot
3
- from graphgen.operators.search.search_all import search_all
4
-
5
  from .judge import judge_statement
 
6
  from .quiz import quiz
7
  from .read import read_files
 
8
  from .split import chunk_documents
9
- from .traverse_graph import (
10
- traverse_graph_for_aggregated,
11
- traverse_graph_for_atomic,
12
- traverse_graph_for_multi_hop,
13
- )
 
1
+ from .build_kg import build_kg
2
+ from .generate import generate_qas
 
 
3
  from .judge import judge_statement
4
+ from .partition import partition_kg
5
  from .quiz import quiz
6
  from .read import read_files
7
+ from .search import search_all
8
  from .split import chunk_documents
 
 
 
 
 
graphgen/operators/build_kg/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .build_kg import build_kg
graphgen/operators/build_kg/split_kg.py DELETED
@@ -1,382 +0,0 @@
1
- import random
2
- from collections import defaultdict
3
- from typing import Dict
4
-
5
- from tqdm.asyncio import tqdm as tqdm_async
6
-
7
- from graphgen.models import NetworkXStorage
8
- from graphgen.utils import logger
9
-
10
-
11
- async def _get_node_info(
12
- node_id: str,
13
- graph_storage: NetworkXStorage,
14
- ) -> dict:
15
- """
16
- Get node info
17
-
18
- :param node_id: node id
19
- :param graph_storage: graph storage instance
20
- :return: node info
21
- """
22
- node_data = await graph_storage.get_node(node_id)
23
- return {"node_id": node_id, **node_data}
24
-
25
-
26
- def _get_level_n_edges_by_max_width(
27
- edge_adj_list: dict,
28
- node_dict: dict,
29
- edges: list,
30
- nodes,
31
- src_edge: tuple,
32
- max_depth: int,
33
- bidirectional: bool,
34
- max_extra_edges: int,
35
- edge_sampling: str,
36
- loss_strategy: str = "only_edge",
37
- ) -> list:
38
- """
39
- Get level n edges for an edge.
40
- n is decided by max_depth in traverse_strategy
41
-
42
- :param edge_adj_list
43
- :param node_dict
44
- :param edges
45
- :param nodes
46
- :param src_edge
47
- :param max_depth
48
- :param bidirectional
49
- :param max_extra_edges
50
- :param edge_sampling
51
- :return: level n edges
52
- """
53
- src_id, tgt_id, _ = src_edge
54
-
55
- level_n_edges = []
56
-
57
- start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
58
-
59
- while max_depth > 0 and max_extra_edges > 0:
60
- max_depth -= 1
61
-
62
- candidate_edges = [
63
- edges[edge_id]
64
- for node in start_nodes
65
- for edge_id in edge_adj_list[node]
66
- if not edges[edge_id][2].get("visited", False)
67
- ]
68
-
69
- if not candidate_edges:
70
- break
71
-
72
- if len(candidate_edges) >= max_extra_edges:
73
- if loss_strategy == "both":
74
- er_tuples = [
75
- ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
76
- for edge in candidate_edges
77
- ]
78
- candidate_edges = _sort_tuples(er_tuples, edge_sampling)[
79
- :max_extra_edges
80
- ]
81
- elif loss_strategy == "only_edge":
82
- candidate_edges = _sort_edges(candidate_edges, edge_sampling)[
83
- :max_extra_edges
84
- ]
85
- else:
86
- raise ValueError(f"Invalid loss strategy: {loss_strategy}")
87
-
88
- for edge in candidate_edges:
89
- level_n_edges.append(edge)
90
- edge[2]["visited"] = True
91
- break
92
-
93
- max_extra_edges -= len(candidate_edges)
94
- new_start_nodes = set()
95
-
96
- for edge in candidate_edges:
97
- level_n_edges.append(edge)
98
- edge[2]["visited"] = True
99
-
100
- if not edge[0] in start_nodes:
101
- new_start_nodes.add(edge[0])
102
- if not edge[1] in start_nodes:
103
- new_start_nodes.add(edge[1])
104
-
105
- start_nodes = new_start_nodes
106
-
107
- return level_n_edges
108
-
109
-
110
- def _get_level_n_edges_by_max_tokens(
111
- edge_adj_list: dict,
112
- node_dict: dict,
113
- edges: list,
114
- nodes: list,
115
- src_edge: tuple,
116
- max_depth: int,
117
- bidirectional: bool,
118
- max_tokens: int,
119
- edge_sampling: str,
120
- loss_strategy: str = "only_edge",
121
- ) -> list:
122
- """
123
- Get level n edges for an edge.
124
- n is decided by max_depth in traverse_strategy.
125
-
126
- :param edge_adj_list
127
- :param node_dict
128
- :param edges
129
- :param nodes
130
- :param src_edge
131
- :param max_depth
132
- :param bidirectional
133
- :param max_tokens
134
- :param edge_sampling
135
- :return: level n edges
136
- """
137
- src_id, tgt_id, src_edge_data = src_edge
138
-
139
- max_tokens -= (
140
- src_edge_data["length"]
141
- + nodes[node_dict[src_id]][1]["length"]
142
- + nodes[node_dict[tgt_id]][1]["length"]
143
- )
144
-
145
- level_n_edges = []
146
-
147
- start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id}
148
- temp_nodes = {src_id, tgt_id}
149
-
150
- while max_depth > 0 and max_tokens > 0:
151
- max_depth -= 1
152
-
153
- candidate_edges = [
154
- edges[edge_id]
155
- for node in start_nodes
156
- for edge_id in edge_adj_list[node]
157
- if not edges[edge_id][2].get("visited", False)
158
- ]
159
-
160
- if not candidate_edges:
161
- break
162
-
163
- if loss_strategy == "both":
164
- er_tuples = [
165
- ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
166
- for edge in candidate_edges
167
- ]
168
- candidate_edges = _sort_tuples(er_tuples, edge_sampling)
169
- elif loss_strategy == "only_edge":
170
- candidate_edges = _sort_edges(candidate_edges, edge_sampling)
171
- else:
172
- raise ValueError(f"Invalid loss strategy: {loss_strategy}")
173
-
174
- for edge in candidate_edges:
175
- max_tokens -= edge[2]["length"]
176
- if not edge[0] in temp_nodes:
177
- max_tokens -= nodes[node_dict[edge[0]]][1]["length"]
178
- if not edge[1] in temp_nodes:
179
- max_tokens -= nodes[node_dict[edge[1]]][1]["length"]
180
-
181
- if max_tokens < 0:
182
- return level_n_edges
183
-
184
- level_n_edges.append(edge)
185
- edge[2]["visited"] = True
186
- temp_nodes.add(edge[0])
187
- temp_nodes.add(edge[1])
188
-
189
- new_start_nodes = set()
190
- for edge in candidate_edges:
191
- if not edge[0] in start_nodes:
192
- new_start_nodes.add(edge[0])
193
- if not edge[1] in start_nodes:
194
- new_start_nodes.add(edge[1])
195
-
196
- start_nodes = new_start_nodes
197
-
198
- return level_n_edges
199
-
200
-
201
- def _sort_tuples(er_tuples: list, edge_sampling: str) -> list:
202
- """
203
- Sort edges with edge sampling strategy
204
-
205
- :param er_tuples: [(nodes:list, edge:tuple)]
206
- :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
207
- :return: sorted edges
208
- """
209
- if edge_sampling == "random":
210
- er_tuples = random.sample(er_tuples, len(er_tuples))
211
- elif edge_sampling == "min_loss":
212
- er_tuples = sorted(
213
- er_tuples,
214
- key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
215
- )
216
- elif edge_sampling == "max_loss":
217
- er_tuples = sorted(
218
- er_tuples,
219
- key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"],
220
- reverse=True,
221
- )
222
- else:
223
- raise ValueError(f"Invalid edge sampling: {edge_sampling}")
224
- edges = [edge for _, edge in er_tuples]
225
- return edges
226
-
227
-
228
- def _sort_edges(edges: list, edge_sampling: str) -> list:
229
- """
230
- Sort edges with edge sampling strategy
231
-
232
- :param edges: total edges
233
- :param edge_sampling: edge sampling strategy (random, min_loss, max_loss)
234
- :return: sorted edges
235
- """
236
- if edge_sampling == "random":
237
- random.shuffle(edges)
238
- elif edge_sampling == "min_loss":
239
- edges = sorted(edges, key=lambda x: x[2]["loss"])
240
- elif edge_sampling == "max_loss":
241
- edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True)
242
- else:
243
- raise ValueError(f"Invalid edge sampling: {edge_sampling}")
244
- return edges
245
-
246
-
247
- async def get_batches_with_strategy( # pylint: disable=too-many-branches
248
- nodes: list,
249
- edges: list,
250
- graph_storage: NetworkXStorage,
251
- traverse_strategy: Dict,
252
- ):
253
- expand_method = traverse_strategy["expand_method"]
254
- if expand_method == "max_width":
255
- logger.info("Using max width strategy")
256
- elif expand_method == "max_tokens":
257
- logger.info("Using max tokens strategy")
258
- else:
259
- raise ValueError(f"Invalid expand method: {expand_method}")
260
-
261
- max_depth = traverse_strategy["max_depth"]
262
- edge_sampling = traverse_strategy["edge_sampling"]
263
-
264
- # 构建临接矩阵
265
- edge_adj_list = defaultdict(list)
266
- node_dict = {}
267
- processing_batches = []
268
-
269
- node_cache = {}
270
-
271
- async def get_cached_node_info(node_id: str) -> dict:
272
- if node_id not in node_cache:
273
- node_cache[node_id] = await _get_node_info(node_id, graph_storage)
274
- return node_cache[node_id]
275
-
276
- for i, (node_name, _) in enumerate(nodes):
277
- node_dict[node_name] = i
278
-
279
- if traverse_strategy["loss_strategy"] == "both":
280
- er_tuples = [
281
- ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge)
282
- for edge in edges
283
- ]
284
- edges = _sort_tuples(er_tuples, edge_sampling)
285
- elif traverse_strategy["loss_strategy"] == "only_edge":
286
- edges = _sort_edges(edges, edge_sampling)
287
- else:
288
- raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}")
289
-
290
- for i, (src, tgt, _) in enumerate(edges):
291
- edge_adj_list[src].append(i)
292
- edge_adj_list[tgt].append(i)
293
-
294
- for edge in tqdm_async(edges, desc="Preparing batches"):
295
- if "visited" in edge[2] and edge[2]["visited"]:
296
- continue
297
-
298
- edge[2]["visited"] = True
299
-
300
- _process_nodes = []
301
- _process_edges = []
302
-
303
- src_id = edge[0]
304
- tgt_id = edge[1]
305
-
306
- _process_nodes.extend(
307
- [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)]
308
- )
309
- _process_edges.append(edge)
310
-
311
- if expand_method == "max_width":
312
- level_n_edges = _get_level_n_edges_by_max_width(
313
- edge_adj_list,
314
- node_dict,
315
- edges,
316
- nodes,
317
- edge,
318
- max_depth,
319
- traverse_strategy["bidirectional"],
320
- traverse_strategy["max_extra_edges"],
321
- edge_sampling,
322
- traverse_strategy["loss_strategy"],
323
- )
324
- else:
325
- level_n_edges = _get_level_n_edges_by_max_tokens(
326
- edge_adj_list,
327
- node_dict,
328
- edges,
329
- nodes,
330
- edge,
331
- max_depth,
332
- traverse_strategy["bidirectional"],
333
- traverse_strategy["max_tokens"],
334
- edge_sampling,
335
- traverse_strategy["loss_strategy"],
336
- )
337
-
338
- for _edge in level_n_edges:
339
- _process_nodes.append(await get_cached_node_info(_edge[0]))
340
- _process_nodes.append(await get_cached_node_info(_edge[1]))
341
- _process_edges.append(_edge)
342
-
343
- # 去重
344
- _process_nodes = list(
345
- {node["node_id"]: node for node in _process_nodes}.values()
346
- )
347
- _process_edges = list(
348
- {(edge[0], edge[1]): edge for edge in _process_edges}.values()
349
- )
350
-
351
- processing_batches.append((_process_nodes, _process_edges))
352
-
353
- logger.info("Processing batches: %d", len(processing_batches))
354
-
355
- # isolate nodes
356
- isolated_node_strategy = traverse_strategy["isolated_node_strategy"]
357
- if isolated_node_strategy == "add":
358
- processing_batches = await _add_isolated_nodes(
359
- nodes, processing_batches, graph_storage
360
- )
361
- logger.info(
362
- "Processing batches after adding isolated nodes: %d",
363
- len(processing_batches),
364
- )
365
-
366
- return processing_batches
367
-
368
-
369
- async def _add_isolated_nodes(
370
- nodes: list,
371
- processing_batches: list,
372
- graph_storage: NetworkXStorage,
373
- ) -> list:
374
- visited_nodes = set()
375
- for _process_nodes, _process_edges in processing_batches:
376
- for node in _process_nodes:
377
- visited_nodes.add(node["node_id"])
378
- for node in nodes:
379
- if node[0] not in visited_nodes:
380
- _process_nodes = [await _get_node_info(node[0], graph_storage)]
381
- processing_batches.append((_process_nodes, []))
382
- return processing_batches
 
 
 
 
graphgen/operators/generate/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .generate_qas import generate_qas
graphgen/operators/generate/generate_cot.py DELETED
@@ -1,117 +0,0 @@
1
- import asyncio
2
- from typing import Dict, List, Tuple
3
-
4
- from tqdm.asyncio import tqdm as tqdm_async
5
-
6
- from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient
7
- from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
8
- from graphgen.utils import compute_content_hash, detect_main_language
9
-
10
-
11
- async def generate_cot(
12
- graph_storage: NetworkXStorage,
13
- synthesizer_llm_client: OpenAIClient,
14
- method_params: Dict = None,
15
- ):
16
- method = method_params.get("method", "leiden")
17
- detector = CommunityDetector(
18
- graph_storage=graph_storage, method=method, method_params=method_params
19
- )
20
-
21
- results = await detector.detect_communities()
22
-
23
- # Convert results to a format suitable for summarization
24
- communities = {}
25
- for node, community_id in results.items():
26
- if community_id not in communities:
27
- communities[community_id] = []
28
- communities[community_id].append(node)
29
-
30
- if not communities:
31
- return {}
32
-
33
- semaphore = asyncio.Semaphore(value=1000)
34
-
35
- async def _generate_from_single_community(
36
- c_id: int, nodes: List[str]
37
- ) -> Tuple[int, Tuple[str, str, str]]:
38
- """Summarize a single community."""
39
- async with semaphore:
40
- entities: List[str] = []
41
- relationships: List[str] = []
42
-
43
- for n in nodes:
44
- node_data = await graph_storage.get_node(n)
45
- if node_data is not None:
46
- entities.append(f"({n}: {node_data.get('description')})")
47
-
48
- edges = await graph_storage.get_node_edges(n)
49
- for edge in edges:
50
- target = edge[1]
51
- if target in nodes:
52
- edge_data = await graph_storage.get_edge(n, target)
53
- relationships.append(
54
- f"({n}) - [{edge_data['description']}] -> ({target})"
55
- )
56
-
57
- entities_str = "\n".join(entities)
58
- relationships_str = "\n".join(relationships)
59
-
60
- language = (
61
- "English"
62
- if detect_main_language(entities_str + relationships_str) == "en"
63
- else "Chinese"
64
- )
65
-
66
- prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format(
67
- entities=entities_str,
68
- relationships=relationships_str,
69
- )
70
-
71
- cot_template = await synthesizer_llm_client.generate_answer(prompt)
72
-
73
- if "问题:" in cot_template and "推理路径设计:" in cot_template:
74
- question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip()
75
- reasoning_path = cot_template.split("推理路径设计:")[1].strip()
76
- elif (
77
- "Question:" in cot_template and "Reasoning-Path Design:" in cot_template
78
- ):
79
- question = (
80
- cot_template.split("Question:")[1]
81
- .split("Reasoning-Path Design:")[0]
82
- .strip()
83
- )
84
- reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip()
85
- else:
86
- raise ValueError("COT template format is incorrect.")
87
-
88
- prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format(
89
- entities=entities_str,
90
- relationships=relationships_str,
91
- question=question,
92
- reasoning_template=reasoning_path,
93
- )
94
-
95
- cot_answer = await synthesizer_llm_client.generate_answer(prompt)
96
-
97
- return c_id, (question, reasoning_path, cot_answer)
98
-
99
- cid_nodes = list(communities.items())
100
-
101
- results: Dict = {}
102
- async for coro in tqdm_async(
103
- asyncio.as_completed(
104
- [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes]
105
- ),
106
- total=len(cid_nodes),
107
- desc="[Generating COT] Generating CoT data from communities",
108
- unit="community",
109
- ):
110
- cid, (q, r, a) = await coro
111
- results[compute_content_hash(q)] = {
112
- "question": q,
113
- "reasoning_path": r,
114
- "answer": a,
115
- }
116
-
117
- return results
 
 
 
 
graphgen/operators/generate/generate_qas.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
1
+ from typing import Any
2
+
3
+ from graphgen.bases import BaseLLMClient
4
+ from graphgen.models import (
5
+ AggregatedGenerator,
6
+ AtomicGenerator,
7
+ CoTGenerator,
8
+ MultiHopGenerator,
9
+ )
10
+ from graphgen.utils import logger, run_concurrent
11
+
12
+
13
+ async def generate_qas(
14
+ llm_client: BaseLLMClient,
15
+ batches: list[
16
+ tuple[
17
+ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
18
+ ]
19
+ ],
20
+ generation_config: dict,
21
+ ) -> list[dict[str, Any]]:
22
+ """
23
+ Generate question-answer pairs based on nodes and edges.
24
+ :param llm_client: LLM client
25
+ :param batches: list of (nodes, edges) batches produced by the partitioner
26
+ :param generation_config: generation settings, including "mode" and "data_format"
27
+ :return: QA pairs
28
+ """
29
+ mode = generation_config["mode"]
30
+ logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
31
+
32
+ if mode == "atomic":
33
+ generator = AtomicGenerator(llm_client)
34
+ elif mode == "aggregated":
35
+ generator = AggregatedGenerator(llm_client)
36
+ elif mode == "multi_hop":
37
+ generator = MultiHopGenerator(llm_client)
38
+ elif mode == "cot":
39
+ generator = CoTGenerator(llm_client)
40
+ else:
41
+ raise ValueError(f"Unsupported generation mode: {mode}")
42
+
43
+ results = await run_concurrent(
44
+ generator.generate,
45
+ batches,
46
+ desc="[4/4]Generating QAs",
47
+ unit="batch",
48
+ )
49
+
50
+ # format
51
+ data_format = generation_config["data_format"]
52
+ logger.info("Output data format: %s", data_format)
53
+
54
+ results = generator.format_generation_results(
55
+ results, output_data_format=data_format
56
+ )
57
+
58
+ return results
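A small sketch of the same mode dispatch written as a lookup table instead of an if/elif chain (assumes the generator classes are importable exactly as above; functionally equivalent):

from graphgen.models import (
    AggregatedGenerator,
    AtomicGenerator,
    CoTGenerator,
    MultiHopGenerator,
)

_GENERATORS = {
    "atomic": AtomicGenerator,
    "aggregated": AggregatedGenerator,
    "multi_hop": MultiHopGenerator,
    "cot": CoTGenerator,
}

def build_generator(mode: str, llm_client):
    """Map a generation-config mode onto its generator class."""
    try:
        return _GENERATORS[mode](llm_client)
    except KeyError as exc:
        raise ValueError(f"Unsupported generation mode: {mode}") from exc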
graphgen/operators/partition/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .partition_kg import partition_kg
graphgen/operators/partition/partition_kg.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
1
+ from typing import Any
2
+
3
+ from graphgen.bases import BaseGraphStorage, BaseTokenizer
4
+ from graphgen.models import (
5
+ BFSPartitioner,
6
+ DFSPartitioner,
7
+ ECEPartitioner,
8
+ LeidenPartitioner,
9
+ )
10
+ from graphgen.utils import logger
11
+
12
+ from .pre_tokenize import pre_tokenize
13
+
14
+
15
+ async def partition_kg(
16
+ kg_instance: BaseGraphStorage,
17
+ tokenizer: BaseTokenizer = None,
18
+ partition_config: dict = None,
19
+ ) -> list[
20
+ tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
21
+ ]:
22
+ method = partition_config["method"]
23
+ method_params = partition_config["method_params"]
24
+ if method == "bfs":
25
+ logger.info("Partitioning knowledge graph using BFS method.")
26
+ partitioner = BFSPartitioner()
27
+ elif method == "dfs":
28
+ logger.info("Partitioning knowledge graph using DFS method.")
29
+ partitioner = DFSPartitioner()
30
+ elif method == "ece":
31
+ logger.info("Partitioning knowledge graph using ECE method.")
32
+ # TODO: before ECE partitioning, we need to:
33
+ # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
34
+ # 2. pre-tokenize nodes and edges to get the token length
35
+ edges = await kg_instance.get_all_edges()
36
+ nodes = await kg_instance.get_all_nodes()
37
+ await pre_tokenize(kg_instance, tokenizer, edges, nodes)
38
+ partitioner = ECEPartitioner()
39
+ elif method == "leiden":
40
+ logger.info("Partitioning knowledge graph using Leiden method.")
41
+ partitioner = LeidenPartitioner()
42
+ else:
43
+ raise ValueError(f"Unsupported partition method: {method}")
44
+
45
+ communities = await partitioner.partition(g=kg_instance, **method_params)
46
+ logger.info("Partitioned the graph into %d communities.", len(communities))
47
+ batches = await partitioner.community2batch(communities, g=kg_instance)
48
+ return batches
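For reference, the shape of the partition_config expected here, shown as a Python dict. The values are illustrative only; the Leiden parameters mirror the defaults in LeidenPartitioner.partition, and "method" accepts bfs / dfs / ece / leiden:

partition_config = {
    "method": "leiden",
    "method_params": {
        "max_size": 20,        # split communities larger than this many nodes
        "use_lcc": False,      # keep all connected components
        "random_seed": 42,
    },
}
# batches = await partition_kg(kg_instance, tokenizer, partition_config)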
graphgen/operators/partition/pre_tokenize.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
1
+ import asyncio
2
+ from typing import List, Tuple
3
+
4
+ from graphgen.bases import BaseGraphStorage, BaseTokenizer
5
+ from graphgen.utils import run_concurrent
6
+
7
+
8
+ async def pre_tokenize(
9
+ graph_storage: BaseGraphStorage,
10
+ tokenizer: BaseTokenizer,
11
+ edges: List[Tuple],
12
+ nodes: List[Tuple],
13
+ ) -> Tuple[List, List]:
14
+ """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。"""
15
+ sem = asyncio.Semaphore(1000)
16
+
17
+ async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
18
+ async with sem:
19
+ data = obj[1] if is_node else obj[2]
20
+ if "length" not in data:
21
+ loop = asyncio.get_event_loop()
22
+ data["length"] = len(
23
+ await loop.run_in_executor(
24
+ None, tokenizer.encode, data["description"]
25
+ )
26
+ )
27
+ if is_node:
28
+ await graph_storage.update_node(obj[0], obj[1])
29
+ else:
30
+ await graph_storage.update_edge(obj[0], obj[1], obj[2])
31
+ return obj
32
+
33
+ new_edges, new_nodes = await asyncio.gather(
34
+ run_concurrent(
35
+ lambda e: _patch_and_write(e, is_node=False),
36
+ edges,
37
+ desc="Pre-tokenizing edges",
38
+ ),
39
+ run_concurrent(
40
+ lambda n: _patch_and_write(n, is_node=True),
41
+ nodes,
42
+ desc="Pre-tokenizing nodes",
43
+ ),
44
+ )
45
+
46
+ await graph_storage.index_done_callback()
47
+ return new_edges, new_nodes
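A dependency-free sketch of the patch step above: attach a "length" field (token count of the description) to a node/edge record, here with a whitespace split standing in for the real BaseTokenizer.encode (all names are illustrative; the real code also writes the patched records back to graph storage):

node = ("Entity-A", {"description": "battery electrode material"})
edge = ("Entity-A", "Entity-B", {"description": "is used in"})

def fake_encode(text: str) -> list:
    return text.split()  # stand-in for tokenizer.encode

for record in (node[1], edge[2]):
    if "length" not in record:
        record["length"] = len(fake_encode(record["description"]))

print(node[1]["length"], edge[2]["length"])  # token counts later consumed by the ECE partitioner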
graphgen/operators/search/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ from .search_all import search_all
graphgen/operators/traverse_graph.py DELETED
@@ -1,540 +0,0 @@
1
- import asyncio
2
- from typing import Dict
3
-
4
- import gradio as gr
5
- from tqdm.asyncio import tqdm as tqdm_async
6
-
7
- from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer
8
- from graphgen.operators.build_kg.split_kg import get_batches_with_strategy
9
- from graphgen.templates import (
10
- ANSWER_REPHRASING_PROMPT,
11
- MULTI_HOP_GENERATION_PROMPT,
12
- QUESTION_GENERATION_PROMPT,
13
- )
14
- from graphgen.utils import compute_content_hash, detect_main_language, logger
15
-
16
-
17
- async def _pre_tokenize(
18
- graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list
19
- ) -> tuple:
20
-
21
- sem = asyncio.Semaphore(1000)
22
-
23
- async def handle_edge(edge: tuple) -> tuple:
24
- async with sem:
25
- if "length" not in edge[2]:
26
- edge[2]["length"] = len(
27
- await asyncio.get_event_loop().run_in_executor(
28
- None, tokenizer.encode, edge[2]["description"]
29
- )
30
- )
31
- return edge
32
-
33
- async def handle_node(node: dict) -> dict:
34
- async with sem:
35
- if "length" not in node[1]:
36
- node[1]["length"] = len(
37
- await asyncio.get_event_loop().run_in_executor(
38
- None, tokenizer.encode, node[1]["description"]
39
- )
40
- )
41
- return node
42
-
43
- new_edges = []
44
- new_nodes = []
45
-
46
- for result in tqdm_async(
47
- asyncio.as_completed([handle_edge(edge) for edge in edges]),
48
- total=len(edges),
49
- desc="Pre-tokenizing edges",
50
- ):
51
- new_edge = await result
52
- await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2])
53
- new_edges.append(new_edge)
54
-
55
- for result in tqdm_async(
56
- asyncio.as_completed([handle_node(node) for node in nodes]),
57
- total=len(nodes),
58
- desc="Pre-tokenizing nodes",
59
- ):
60
- new_node = await result
61
- await graph_storage.update_node(new_node[0], new_node[1])
62
- new_nodes.append(new_node)
63
-
64
- await graph_storage.index_done_callback()
65
- return new_edges, new_nodes
66
-
67
-
68
- async def _construct_rephrasing_prompt(
69
- _process_nodes: list,
70
- _process_edges: list,
71
- text_chunks_storage: JsonKVStorage,
72
- add_context: bool = False,
73
- ) -> str:
74
- entities = [
75
- f"{_process_node['node_id']}: {_process_node['description']}"
76
- for _process_node in _process_nodes
77
- ]
78
- relations = [
79
- f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
80
- for _process_edge in _process_edges
81
- ]
82
-
83
- entities_str = "\n".join(
84
- [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
85
- )
86
- relations_str = "\n".join(
87
- [f"{index + 1}. {relation}" for index, relation in enumerate(relations)]
88
- )
89
- language = (
90
- "Chinese"
91
- if detect_main_language(entities_str + relations_str) == "zh"
92
- else "English"
93
- )
94
-
95
- if add_context:
96
- original_ids = [
97
- node["source_id"].split("<SEP>")[0] for node in _process_nodes
98
- ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
99
-
100
- original_ids = list(set(original_ids))
101
- original_text = await text_chunks_storage.get_by_ids(original_ids)
102
- original_text = "\n".join(
103
- [
104
- f"{index + 1}. {text['content']}"
105
- for index, text in enumerate(original_text)
106
- ]
107
- )
108
-
109
- prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format(
110
- language=language,
111
- original_text=original_text,
112
- entities=entities_str,
113
- relationships=relations_str,
114
- )
115
- return prompt
116
-
117
- prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format(
118
- language=language, entities=entities_str, relationships=relations_str
119
- )
120
- return prompt
121
-
122
-
123
- def get_average_loss(batch: tuple, loss_strategy: str) -> float:
124
- try:
125
- if loss_strategy == "only_edge":
126
- return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1])
127
- if loss_strategy == "both":
128
- return sum(edge[2]["loss"] for edge in batch[1]) + sum(
129
- node["loss"] for node in batch[0]
130
- ) / (len(batch[0]) + len(batch[1]))
131
- raise ValueError("Invalid loss strategy")
132
- except Exception as e: # pylint: disable=broad-except
133
- logger.warning(
134
- "Loss not found in some nodes or edges, setting loss to -1.0: %s", e
135
- )
136
- return -1.0
137
-
138
-
139
- def _post_process_synthetic_data(data):
140
- block = data.split("\n\n")
141
- qas = []
142
- for line in block:
143
- if "Question:" in line and "Answer:" in line:
144
- question = line.split("Question:")[1].split("Answer:")[0].strip()
145
- answer = line.split("Answer:")[1].strip()
146
- qas.append({"question": question, "answer": answer})
147
- elif "问题:" in line and "答案:" in line:
148
- question = line.split("问题:")[1].split("答案:")[0].strip()
149
- answer = line.split("答案:")[1].strip()
150
- qas.append({"question": question, "answer": answer})
151
- elif "问题:" in line and "回答:" in line:
152
- question = line.split("问题:")[1].split("回答:")[0].strip()
153
- answer = line.split("回答:")[1].strip()
154
- qas.append({"question": question, "answer": answer})
155
- return qas
156
-
157
-
158
- async def traverse_graph_for_aggregated(
159
- llm_client: OpenAIClient,
160
- tokenizer: Tokenizer,
161
- graph_storage: NetworkXStorage,
162
- traverse_strategy: Dict,
163
- text_chunks_storage: JsonKVStorage,
164
- progress_bar: gr.Progress = None,
165
- max_concurrent: int = 1000,
166
- ) -> dict:
167
- """
168
- Traverse the graph
169
-
170
- :param llm_client
171
- :param tokenizer
172
- :param graph_storage
173
- :param traverse_strategy
174
- :param text_chunks_storage
175
- :param progress_bar
176
- :param max_concurrent
177
- :return: question and answer
178
- """
179
-
180
- semaphore = asyncio.Semaphore(max_concurrent)
181
-
182
- async def _process_nodes_and_edges(
183
- _process_nodes: list,
184
- _process_edges: list,
185
- ) -> str:
186
- prompt = await _construct_rephrasing_prompt(
187
- _process_nodes, _process_edges, text_chunks_storage, add_context=False
188
- )
189
- context = await llm_client.generate_answer(prompt)
190
-
191
- # post-process the context
192
- if context.startswith("Rephrased Text:"):
193
- context = context[len("Rephrased Text:") :].strip()
194
- elif context.startswith("重述文本:"):
195
- context = context[len("重述文本:") :].strip()
196
-
197
- return context
198
-
199
- async def _process_single_batch(
200
- _process_batch: tuple, question_type: str = "single"
201
- ) -> dict:
202
- async with semaphore:
203
- context = await _process_nodes_and_edges(
204
- _process_batch[0],
205
- _process_batch[1],
206
- )
207
-
208
- language = "Chinese" if detect_main_language(context) == "zh" else "English"
209
- pre_length = sum(node["length"] for node in _process_batch[0]) + sum(
210
- edge[2]["length"] for edge in _process_batch[1]
211
- )
212
-
213
- if question_type == "single":
214
- question = await llm_client.generate_answer(
215
- QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format(
216
- answer=context
217
- )
218
- )
219
- if question.startswith("Question:"):
220
- question = question[len("Question:") :].strip()
221
- elif question.startswith("问题:"):
222
- question = question[len("问题:") :].strip()
223
-
224
- logger.info(
225
- "%d nodes and %d edges processed",
226
- len(_process_batch[0]),
227
- len(_process_batch[1]),
228
- )
229
- logger.info("Pre-length: %s", pre_length)
230
- logger.info("Question: %s", question)
231
- logger.info("Answer: %s", context)
232
-
233
- return {
234
- compute_content_hash(context): {
235
- "question": question,
236
- "answer": context,
237
- "loss": get_average_loss(
238
- _process_batch, traverse_strategy["loss_strategy"]
239
- ),
240
- }
241
- }
242
-
243
- content = await llm_client.generate_answer(
244
- QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format(
245
- doc=context
246
- )
247
- )
248
- qas = _post_process_synthetic_data(content)
249
-
250
- if len(qas) == 0:
251
- logger.error(
252
- "Error occurred while processing batch, question or answer is None"
253
- )
254
- return {}
255
-
256
- final_results = {}
257
- logger.info(
258
- "%d nodes and %d edges processed",
259
- len(_process_batch[0]),
260
- len(_process_batch[1]),
261
- )
262
- logger.info("Pre-length: %s", pre_length)
263
- for qa in qas:
264
- logger.info("Question: %s", qa["question"])
265
- logger.info("Answer: %s", qa["answer"])
266
- final_results[compute_content_hash(qa["question"])] = {
267
- "question": qa["question"],
268
- "answer": qa["answer"],
269
- "loss": get_average_loss(
270
- _process_batch, traverse_strategy["loss_strategy"]
271
- ),
272
- }
273
- return final_results
274
-
275
- results = {}
276
- edges = list(await graph_storage.get_all_edges())
277
- nodes = list(await graph_storage.get_all_nodes())
278
-
279
- edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
280
-
281
- processing_batches = await get_batches_with_strategy(
282
- nodes, edges, graph_storage, traverse_strategy
283
- )
284
-
285
- for result in tqdm_async(
286
- asyncio.as_completed(
287
- [_process_single_batch(batch) for batch in processing_batches]
288
- ),
289
- total=len(processing_batches),
290
- desc="[4/4]Generating QAs",
291
- ):
292
- try:
293
- if progress_bar is not None:
294
- progress_bar(
295
- len(results) / len(processing_batches), desc="[4/4]Generating QAs"
296
- )
297
- results.update(await result)
298
- if progress_bar is not None and len(results) == len(processing_batches):
299
- progress_bar(1, desc="[4/4]Generating QAs")
300
- except Exception as e: # pylint: disable=broad-except
301
- logger.error("Error occurred while generating QA: %s", e)
302
-
303
- return results
304
-
305
-
306
- # pylint: disable=too-many-branches, too-many-statements
307
- async def traverse_graph_for_atomic(
308
- llm_client: OpenAIClient,
309
- tokenizer: Tokenizer,
310
- graph_storage: NetworkXStorage,
311
- traverse_strategy: Dict,
312
- text_chunks_storage: JsonKVStorage,
313
- progress_bar: gr.Progress = None,
314
- max_concurrent: int = 1000,
315
- ) -> dict:
316
- """
317
- Traverse the graph atomically
318
-
319
- :param llm_client
320
- :param tokenizer
321
- :param graph_storage
322
- :param traverse_strategy
323
- :param text_chunks_storage
324
- :param progress_bar
325
- :param max_concurrent
326
- :return: question and answer
327
- """
328
-
329
- semaphore = asyncio.Semaphore(max_concurrent)
330
-
331
- def _parse_qa(qa: str) -> tuple:
332
- if "Question:" in qa and "Answer:" in qa:
333
- question = qa.split("Question:")[1].split("Answer:")[0].strip()
334
- answer = qa.split("Answer:")[1].strip()
335
- elif "问题:" in qa and "答案:" in qa:
336
- question = qa.split("问题:")[1].split("答案:")[0].strip()
337
- answer = qa.split("答案:")[1].strip()
338
- else:
339
- return None, None
340
- return question.strip('"'), answer.strip('"')
341
-
342
- async def _generate_question(node_or_edge: tuple):
343
- if len(node_or_edge) == 2:
344
- des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
345
- loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
346
- else:
347
- des = node_or_edge[2]["description"]
348
- loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
349
-
350
- async with semaphore:
351
- try:
352
- language = "Chinese" if detect_main_language(des) == "zh" else "English"
353
-
354
- qa = await llm_client.generate_answer(
355
- QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
356
- doc=des
357
- )
358
- )
359
-
360
- question, answer = _parse_qa(qa)
361
- if question is None or answer is None:
362
- return {}
363
-
364
- question = question.strip('"')
365
- answer = answer.strip('"')
366
-
367
- logger.info("Question: %s", question)
368
- logger.info("Answer: %s", answer)
369
- return {
370
- compute_content_hash(question): {
371
- "question": question,
372
- "answer": answer,
373
- "loss": loss,
374
- }
375
- }
376
- except Exception as e: # pylint: disable=broad-except
377
- logger.error("Error occurred while generating question: %s", e)
378
- return {}
379
-
380
- results = {}
381
- edges = list(await graph_storage.get_all_edges())
382
- nodes = list(await graph_storage.get_all_nodes())
383
-
384
- edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
385
-
386
- tasks = []
387
- for node in nodes:
388
- if "<SEP>" in node[1]["description"]:
389
- description_list = node[1]["description"].split("<SEP>")
390
- for item in description_list:
391
- tasks.append((node[0], {"description": item}))
392
- if "loss" in node[1]:
393
- tasks[-1][1]["loss"] = node[1]["loss"]
394
- else:
395
- tasks.append((node[0], node[1]))
396
- for edge in edges:
397
- if "<SEP>" in edge[2]["description"]:
398
- description_list = edge[2]["description"].split("<SEP>")
399
- for item in description_list:
400
- tasks.append((edge[0], edge[1], {"description": item}))
401
- if "loss" in edge[2]:
402
- tasks[-1][2]["loss"] = edge[2]["loss"]
403
- else:
404
- tasks.append((edge[0], edge[1], edge[2]))
405
-
406
- for result in tqdm_async(
407
- asyncio.as_completed([_generate_question(task) for task in tasks]),
408
- total=len(tasks),
409
- desc="[4/4]Generating QAs",
410
- ):
411
- try:
412
- if progress_bar is not None:
413
- progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
414
- results.update(await result)
415
- if progress_bar is not None and len(results) == len(tasks):
416
- progress_bar(1, desc="[4/4]Generating QAs")
417
- except Exception as e: # pylint: disable=broad-except
418
- logger.error("Error occurred while generating QA: %s", e)
419
- return results
420
-
421
-
422
- async def traverse_graph_for_multi_hop(
423
- llm_client: OpenAIClient,
424
- tokenizer: Tokenizer,
425
- graph_storage: NetworkXStorage,
426
- traverse_strategy: Dict,
427
- text_chunks_storage: JsonKVStorage,
428
- progress_bar: gr.Progress = None,
429
- max_concurrent: int = 1000,
430
- ) -> dict:
431
- """
432
- Traverse the graph for multi-hop
433
-
434
- :param llm_client
435
- :param tokenizer
436
- :param graph_storage
437
- :param traverse_strategy
438
- :param text_chunks_storage
439
- :param progress_bar
440
- :param max_concurrent
441
- :return: question and answer
442
- """
443
- semaphore = asyncio.Semaphore(max_concurrent)
444
-
445
- results = {}
446
- edges = list(await graph_storage.get_all_edges())
447
- nodes = list(await graph_storage.get_all_nodes())
448
-
449
- edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
450
-
451
- processing_batches = await get_batches_with_strategy(
452
- nodes, edges, graph_storage, traverse_strategy
453
- )
454
-
455
- async def _process_single_batch(_process_batch: tuple) -> dict:
456
- async with semaphore:
457
- try:
458
- language = (
459
- "Chinese"
460
- if detect_main_language(_process_batch[0][0]["description"]) == "zh"
461
- else "English"
462
- )
463
-
464
- _process_nodes = _process_batch[0]
465
- _process_edges = _process_batch[1]
466
-
467
- entities = [
468
- f"{_process_node['node_id']}: {_process_node['description']}"
469
- for _process_node in _process_nodes
470
- ]
471
-
472
- relations = [
473
- f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}"
474
- for _process_edge in _process_edges
475
- ]
476
-
477
- entities_str = "\n".join(
478
- [f"{index + 1}. {entity}" for index, entity in enumerate(entities)]
479
- )
480
- relations_str = "\n".join(
481
- [
482
- f"{index + 1}. {relation}"
483
- for index, relation in enumerate(relations)
484
- ]
485
- )
486
-
487
- prompt = MULTI_HOP_GENERATION_PROMPT[language].format(
488
- entities=entities_str, relationships=relations_str
489
- )
490
-
491
- context = await llm_client.generate_answer(prompt)
492
-
493
- # post-process the context
494
- if "Question:" in context and "Answer:" in context:
495
- question = context.split("Question:")[1].split("Answer:")[0].strip()
496
- answer = context.split("Answer:")[1].strip()
497
- elif "问题:" in context and "答案:" in context:
498
- question = context.split("问题:")[1].split("答案:")[0].strip()
499
- answer = context.split("答案:")[1].strip()
500
- else:
501
- return {}
502
-
503
- question = question.strip('"')
504
- answer = answer.strip('"')
505
-
506
- logger.info("Question: %s", question)
507
- logger.info("Answer: %s", answer)
508
-
509
- return {
510
- compute_content_hash(question): {
511
- "question": question,
512
- "answer": answer,
513
- "loss": get_average_loss(
514
- _process_batch, traverse_strategy["loss_strategy"]
515
- ),
516
- }
517
- }
518
-
519
- except Exception as e: # pylint: disable=broad-except
520
- logger.error("Error occurred while processing batch: %s", e)
521
- return {}
522
-
523
- async for result in tqdm_async(
524
- asyncio.as_completed(
525
- [_process_single_batch(batch) for batch in processing_batches]
526
- ),
527
- total=len(processing_batches),
528
- desc="[4/4]Generating QAs",
529
- ):
530
- try:
531
- if progress_bar is not None:
532
- progress_bar(
533
- len(results) / len(processing_batches), desc="[4/4]Generating QAs"
534
- )
535
- results.update(await result)
536
- if progress_bar is not None and len(results) == len(processing_batches):
537
- progress_bar(1, desc="[4/4]Generating QAs")
538
- except Exception as e: # pylint: disable=broad-except
539
- logger.error("Error occurred while generating QA: %s", e)
540
- return results
 
 
 
 
 
graphgen/templates/__init__.py CHANGED
@@ -1,10 +1,13 @@
1
- from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
2
- from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
3
  from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
4
  from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
 
 
 
 
 
 
5
  from .kg_extraction import KG_EXTRACTION_PROMPT
6
  from .kg_summarization import KG_SUMMARIZATION_PROMPT
7
- from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
8
  from .question_generation import QUESTION_GENERATION_PROMPT
9
  from .search_judgement import SEARCH_JUDGEMENT_PROMPT
10
  from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
 
 
 
1
  from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
2
  from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
3
+ from .generation import (
4
+ AGGREGATED_GENERATION_PROMPT,
5
+ ATOMIC_GENERATION_PROMPT,
6
+ COT_GENERATION_PROMPT,
7
+ MULTI_HOP_GENERATION_PROMPT,
8
+ )
9
  from .kg_extraction import KG_EXTRACTION_PROMPT
10
  from .kg_summarization import KG_SUMMARIZATION_PROMPT
 
11
  from .question_generation import QUESTION_GENERATION_PROMPT
12
  from .search_judgement import SEARCH_JUDGEMENT_PROMPT
13
  from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
graphgen/templates/community/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .cot_generation import COT_GENERATION_PROMPT
2
- from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
 
 
 
graphgen/templates/community/cot_generation.py DELETED
@@ -1,87 +0,0 @@
1
- TEMPLATE_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\
2
- CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。
3
-
4
- -输入格式-
5
- [Entities:]
6
- (实体名:实体描述)
7
- ...
8
-
9
- [Relationships:]
10
- (来源实体)-[关系描述]->(目标实体)
11
- ...
12
-
13
- [Question and Reasoning Path:]
14
- (问题)
15
- (推理路径)
16
-
17
- -输出要求-
18
- 1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。
19
- 2. 使用中文。
20
- 3. 不要使用有序列表或编号。
21
- 4. 请直接给出答案,不要生成无关信息。
22
-
23
- -真实数据-
24
- 输入:
25
- [Entities:]:
26
- {entities}
27
-
28
- [Relationships:]:
29
- {relationships}
30
-
31
- [Question:]:
32
- {question}
33
-
34
- [Reasoning_Template:]:
35
- {reasoning_template}
36
-
37
- 输出:
38
-
39
- """
40
-
41
- TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \
42
- produce one Chain-of-Thought (CoT) sample that strictly follows the template \
43
- and can be directly used for downstream training or inference.
44
- CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \
45
- explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \
46
- only the final answer.
47
-
48
- -Input Format-
49
- [Entities:]:
50
- (ENTITY_NAME: ENTITY_DESCRIPTION)
51
- ...
52
-
53
- [Relationships:]:
54
- (ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET)
55
- ...
56
-
57
- [Question and Reasoning Path:]:
58
- (QUESTION)
59
- (REASONING_PATH)
60
-
61
- -Output Requirements-
62
- 1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words.
63
- 2. Use English.
64
- 3. Do not use ordered lists or numbering.
65
- 4. Do not generate extraneous information, just provide the answer.
66
-
67
- -Real Data-
68
- Input:
69
- [Entities:]:
70
- {entities}
71
-
72
- [Relationships:]:
73
- {relationships}
74
-
75
- [Question:]:
76
- {question}
77
-
78
- [Reasoning_Template:]:
79
- {reasoning_template}
80
-
81
- Output:
82
- """
83
-
84
- COT_GENERATION_PROMPT = {
85
- "Chinese": {"TEMPLATE": TEMPLATE_ZH},
86
- "English": {"TEMPLATE": TEMPLATE_EN},
87
- }