github-actions[bot] committed on
Commit 9e67c3b · 1 Parent(s): e83bd85

Auto-sync from demo at Tue Nov 25 11:19:13 UTC 2025

graphgen/bases/base_partitioner.py CHANGED
@@ -39,16 +39,16 @@ class BasePartitioner(ABC):
             edges = comm.edges
             nodes_data = []
             for node in nodes:
-                node_data = await g.get_node(node)
+                node_data = g.get_node(node)
                 if node_data:
                     nodes_data.append((node, node_data))
             edges_data = []
             for u, v in edges:
-                edge_data = await g.get_edge(u, v)
+                edge_data = g.get_edge(u, v)
                 if edge_data:
                     edges_data.append((u, v, edge_data))
                 else:
-                    edge_data = await g.get_edge(v, u)
+                    edge_data = g.get_edge(v, u)
                     if edge_data:
                         edges_data.append((v, u, edge_data))
             batches.append((nodes_data, edges_data))
graphgen/bases/base_storage.py CHANGED
@@ -9,103 +9,99 @@ class StorageNameSpace:
     working_dir: str = None
     namespace: str = None

-    async def index_done_callback(self):
+    def index_done_callback(self):
         """commit the storage operations after indexing"""

-    async def query_done_callback(self):
+    def query_done_callback(self):
         """commit the storage operations after querying"""


 class BaseListStorage(Generic[T], StorageNameSpace):
-    async def all_items(self) -> list[T]:
+    def all_items(self) -> list[T]:
         raise NotImplementedError

-    async def get_by_index(self, index: int) -> Union[T, None]:
+    def get_by_index(self, index: int) -> Union[T, None]:
         raise NotImplementedError

-    async def append(self, data: T):
+    def append(self, data: T):
         raise NotImplementedError

-    async def upsert(self, data: list[T]):
+    def upsert(self, data: list[T]):
         raise NotImplementedError

-    async def drop(self):
+    def drop(self):
         raise NotImplementedError


 class BaseKVStorage(Generic[T], StorageNameSpace):
-    async def all_keys(self) -> list[str]:
+    def all_keys(self) -> list[str]:
         raise NotImplementedError

-    async def get_by_id(self, id: str) -> Union[T, None]:
+    def get_by_id(self, id: str) -> Union[T, None]:
         raise NotImplementedError

-    async def get_by_ids(
+    def get_by_ids(
         self, ids: list[str], fields: Union[set[str], None] = None
     ) -> list[Union[T, None]]:
         raise NotImplementedError

-    async def get_all(self) -> dict[str, T]:
+    def get_all(self) -> dict[str, T]:
         raise NotImplementedError

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    def filter_keys(self, data: list[str]) -> set[str]:
         """return un-exist keys"""
         raise NotImplementedError

-    async def upsert(self, data: dict[str, T]):
+    def upsert(self, data: dict[str, T]):
         raise NotImplementedError

-    async def drop(self):
+    def drop(self):
         raise NotImplementedError


 class BaseGraphStorage(StorageNameSpace):
-    async def has_node(self, node_id: str) -> bool:
+    def has_node(self, node_id: str) -> bool:
         raise NotImplementedError

-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+    def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         raise NotImplementedError

-    async def node_degree(self, node_id: str) -> int:
+    def node_degree(self, node_id: str) -> int:
         raise NotImplementedError

-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
         raise NotImplementedError

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    def get_node(self, node_id: str) -> Union[dict, None]:
         raise NotImplementedError

-    async def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError

-    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
+    def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
         raise NotImplementedError

-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
         raise NotImplementedError

-    async def update_edge(
+    def update_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         raise NotImplementedError

-    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
+    def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         raise NotImplementedError

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         raise NotImplementedError

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
         raise NotImplementedError

-    async def upsert_edge(
+    def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         raise NotImplementedError

-    async def delete_node(self, node_id: str):
+    def delete_node(self, node_id: str):
         raise NotImplementedError
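
With the base interfaces now synchronous, an implementation overrides plain methods rather than coroutines. A minimal in-memory sketch of the new contract (the DictKVStorage class is hypothetical and only illustrates the shape; it assumes StorageNameSpace is itself a dataclass, mirroring JsonKVStorage further down):

    from dataclasses import dataclass, field
    from typing import Union

    from graphgen.bases.base_storage import BaseKVStorage


    @dataclass
    class DictKVStorage(BaseKVStorage):
        _data: dict = field(default_factory=dict)

        def all_keys(self) -> list[str]:
            return list(self._data.keys())

        def get_by_id(self, id: str) -> Union[dict, None]:
            return self._data.get(id)

        def filter_keys(self, data: list[str]) -> set[str]:
            # keys that are not yet present in the store
            return {k for k in data if k not in self._data}

        def upsert(self, data: dict):
            self._data.update(data)

        def drop(self):
            self._data.clear()
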
graphgen/graphgen.py CHANGED
@@ -104,15 +104,15 @@ class GraphGen:
         # TODO: configurable whether to use coreference resolution

         new_docs = {compute_mm_hash(doc, prefix="doc-"): doc for doc in data}
-        _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys()))
+        _add_doc_keys = self.full_docs_storage.filter_keys(list(new_docs.keys()))
         new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}

         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return

-        await self.full_docs_storage.upsert(new_docs)
-        await self.full_docs_storage.index_done_callback()
+        self.full_docs_storage.upsert(new_docs)
+        self.full_docs_storage.index_done_callback()

     @op("chunk", deps=["read"])
     @async_to_sync_method
@@ -121,7 +121,7 @@ class GraphGen:
         chunk documents into smaller pieces from full_docs_storage if not already present
         """

-        new_docs = await self.meta_storage.get_new_data(self.full_docs_storage)
+        new_docs = self.meta_storage.get_new_data(self.full_docs_storage)
         if len(new_docs) == 0:
             logger.warning("All documents are already in the storage")
             return
@@ -133,9 +133,7 @@
             **chunk_config,
         )

-        _add_chunk_keys = await self.chunks_storage.filter_keys(
-            list(inserting_chunks.keys())
-        )
+        _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
         inserting_chunks = {
             k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
         }
@@ -144,10 +142,10 @@
             logger.warning("All chunks are already in the storage")
             return

-        await self.chunks_storage.upsert(inserting_chunks)
-        await self.chunks_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.full_docs_storage)
-        await self.meta_storage.index_done_callback()
+        self.chunks_storage.upsert(inserting_chunks)
+        self.chunks_storage.index_done_callback()
+        self.meta_storage.mark_done(self.full_docs_storage)
+        self.meta_storage.index_done_callback()

     @op("build_kg", deps=["chunk"])
     @async_to_sync_method
@@ -156,7 +154,7 @@
         build knowledge graph from text chunks
         """
         # Step 1: get new chunks according to meta and chunks storage
-        inserting_chunks = await self.meta_storage.get_new_data(self.chunks_storage)
+        inserting_chunks = self.meta_storage.get_new_data(self.chunks_storage)
         if len(inserting_chunks) == 0:
             logger.warning("All chunks are already in the storage")
             return
@@ -174,9 +172,9 @@
             return

         # Step 3: mark meta
-        await self.graph_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.chunks_storage)
-        await self.meta_storage.index_done_callback()
+        self.graph_storage.index_done_callback()
+        self.meta_storage.mark_done(self.chunks_storage)
+        self.meta_storage.index_done_callback()

         return _add_entities_and_relations

@@ -185,7 +183,7 @@
     async def search(self, search_config: Dict):
         logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))

-        seeds = await self.meta_storage.get_new_data(self.full_docs_storage)
+        seeds = self.meta_storage.get_new_data(self.full_docs_storage)
         if len(seeds) == 0:
             logger.warning("All documents are already been searched")
             return
@@ -194,19 +192,17 @@
             search_config=search_config,
         )

-        _add_search_keys = await self.search_storage.filter_keys(
-            list(search_results.keys())
-        )
+        _add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
         search_results = {
             k: v for k, v in search_results.items() if k in _add_search_keys
         }
         if len(search_results) == 0:
             logger.warning("All search results are already in the storage")
             return
-        await self.search_storage.upsert(search_results)
-        await self.search_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.full_docs_storage)
-        await self.meta_storage.index_done_callback()
+        self.search_storage.upsert(search_results)
+        self.search_storage.index_done_callback()
+        self.meta_storage.mark_done(self.full_docs_storage)
+        self.meta_storage.index_done_callback()

     @op("quiz_and_judge", deps=["build_kg"])
     @async_to_sync_method
@@ -240,8 +236,8 @@
             progress_bar=self.progress_bar,
         )

-        await self.rephrase_storage.index_done_callback()
-        await _update_relations.index_done_callback()
+        self.rephrase_storage.index_done_callback()
+        _update_relations.index_done_callback()

         logger.info("Shutting down trainee LLM client.")
         self.trainee_llm_client.shutdown()
@@ -258,7 +254,7 @@
             self.tokenizer_instance,
             partition_config,
         )
-        await self.partition_storage.upsert(batches)
+        self.partition_storage.upsert(batches)
         return batches

     @op("extract", deps=["chunk"])
@@ -276,10 +272,10 @@
             logger.warning("No information extracted")
             return

-        await self.extract_storage.upsert(results)
-        await self.extract_storage.index_done_callback()
-        await self.meta_storage.mark_done(self.chunks_storage)
-        await self.meta_storage.index_done_callback()
+        self.extract_storage.upsert(results)
+        self.extract_storage.index_done_callback()
+        self.meta_storage.mark_done(self.chunks_storage)
+        self.meta_storage.index_done_callback()

     @op("generate", deps=["partition"])
     @async_to_sync_method
@@ -303,17 +299,17 @@
             return

         # Step 3: store the generated QA pairs
-        await self.qa_storage.upsert(results)
-        await self.qa_storage.index_done_callback()
+        self.qa_storage.upsert(results)
+        self.qa_storage.index_done_callback()

     @async_to_sync_method
     async def clear(self):
-        await self.full_docs_storage.drop()
-        await self.chunks_storage.drop()
-        await self.search_storage.drop()
-        await self.graph_storage.clear()
-        await self.rephrase_storage.drop()
-        await self.qa_storage.drop()
+        self.full_docs_storage.drop()
+        self.chunks_storage.drop()
+        self.search_storage.drop()
+        self.graph_storage.clear()
+        self.rephrase_storage.drop()
+        self.qa_storage.drop()

         logger.info("All caches are cleared")
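The pipeline methods stay async (they still await LLM calls and remain wrapped by @async_to_sync_method); only the storage round-trips lose their await. That is only safe because the storage methods are no longer coroutines: calling an async def without await would merely produce a coroutine object. A small illustration with hypothetical stores:

    import asyncio


    class AsyncStore:
        async def filter_keys(self, keys):
            return set(keys)


    class SyncStore:
        def filter_keys(self, keys):
            return set(keys)


    print(SyncStore().filter_keys(["a"]))                # {'a'}
    print(AsyncStore().filter_keys(["a"]))               # <coroutine object ...>, not a set
    print(asyncio.run(AsyncStore().filter_keys(["a"])))  # {'a'}
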
graphgen/models/kg_builder/light_rag_kg_builder.py CHANGED
@@ -105,7 +105,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
         source_ids = []
         descriptions = []

-        node = await kg_instance.get_node(entity_name)
+        node = kg_instance.get_node(entity_name)
         if node is not None:
             entity_types.append(node["entity_type"])
             source_ids.extend(
@@ -134,7 +134,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
             "description": description,
             "source_id": source_id,
         }
-        await kg_instance.upsert_node(entity_name, node_data=node_data)
+        kg_instance.upsert_node(entity_name, node_data=node_data)

     async def merge_edges(
         self,
@@ -146,7 +146,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
         source_ids = []
         descriptions = []

-        edge = await kg_instance.get_edge(src_id, tgt_id)
+        edge = kg_instance.get_edge(src_id, tgt_id)
         if edge is not None:
             source_ids.extend(
                 split_string_by_multi_markers(edge["source_id"], ["<SEP>"])
@@ -161,8 +161,8 @@ class LightRAGKGBuilder(BaseKGBuilder):
         )

         for insert_id in [src_id, tgt_id]:
-            if not await kg_instance.has_node(insert_id):
-                await kg_instance.upsert_node(
+            if not kg_instance.has_node(insert_id):
+                kg_instance.upsert_node(
                     insert_id,
                     node_data={
                         "source_id": source_id,
@@ -175,7 +175,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
             f"({src_id}, {tgt_id})", description
         )

-        await kg_instance.upsert_edge(
+        kg_instance.upsert_edge(
             src_id,
             tgt_id,
             edge_data={"source_id": source_id, "description": description},
graphgen/models/partitioner/anchor_bfs_partitioner.py CHANGED
@@ -36,8 +36,8 @@ class AnchorBFSPartitioner(BFSPartitioner):
         max_units_per_community: int = 1,
         **kwargs: Any,
     ) -> List[Community]:
-        nodes = await g.get_all_nodes()  # List[tuple[id, meta]]
-        edges = await g.get_all_edges()  # List[tuple[u, v, meta]]
+        nodes = g.get_all_nodes()  # List[tuple[id, meta]]
+        edges = g.get_all_edges()  # List[tuple[u, v, meta]]

         adj, _ = self._build_adjacency_list(nodes, edges)

graphgen/models/partitioner/bfs_partitioner.py CHANGED
@@ -23,8 +23,8 @@ class BFSPartitioner(BasePartitioner):
         max_units_per_community: int = 1,
         **kwargs: Any,
     ) -> List[Community]:
-        nodes = await g.get_all_nodes()
-        edges = await g.get_all_edges()
+        nodes = g.get_all_nodes()
+        edges = g.get_all_edges()

         adj, _ = self._build_adjacency_list(nodes, edges)

graphgen/models/partitioner/dfs_partitioner.py CHANGED
@@ -22,8 +22,8 @@ class DFSPartitioner(BasePartitioner):
         max_units_per_community: int = 1,
         **kwargs: Any,
     ) -> List[Community]:
-        nodes = await g.get_all_nodes()
-        edges = await g.get_all_edges()
+        nodes = g.get_all_nodes()
+        edges = g.get_all_edges()

         adj, _ = self._build_adjacency_list(nodes, edges)

graphgen/models/partitioner/ece_partitioner.py CHANGED
@@ -60,8 +60,8 @@ class ECEPartitioner(BFSPartitioner):
         unit_sampling: str = "random",
         **kwargs: Any,
     ) -> List[Community]:
-        nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
-        edges: List[Tuple[str, str, dict]] = await g.get_all_edges()
+        nodes: List[Tuple[str, dict]] = g.get_all_nodes()
+        edges: List[Tuple[str, str, dict]] = g.get_all_edges()

         adj, _ = self._build_adjacency_list(nodes, edges)
         node_dict = dict(nodes)
graphgen/models/partitioner/leiden_partitioner.py CHANGED
@@ -34,8 +34,8 @@ class LeidenPartitioner(BasePartitioner):
         :param kwargs: other parameters for the leiden algorithm
         :return:
         """
-        nodes = await g.get_all_nodes()  # List[Tuple[str, dict]]
-        edges = await g.get_all_edges()  # List[Tuple[str, str, dict]]
+        nodes = g.get_all_nodes()  # List[Tuple[str, dict]]
+        edges = g.get_all_edges()  # List[Tuple[str, str, dict]]

         node2cid: Dict[str, int] = await self._run_leiden(
             nodes, edges, use_lcc, random_seed
graphgen/models/storage/json_storage.py CHANGED
@@ -7,7 +7,7 @@ from graphgen.utils import load_json, logger, write_json

 @dataclass
 class JsonKVStorage(BaseKVStorage):
-    _data: dict[str, str] = None
+    _data: dict[str, dict] = None

     def __post_init__(self):
         self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json")
@@ -18,16 +18,16 @@ class JsonKVStorage(BaseKVStorage):
     def data(self):
         return self._data

-    async def all_keys(self) -> list[str]:
+    def all_keys(self) -> list[str]:
         return list(self._data.keys())

-    async def index_done_callback(self):
+    def index_done_callback(self):
         write_json(self._data, self._file_name)

-    async def get_by_id(self, id):
+    def get_by_id(self, id):
         return self._data.get(id, None)

-    async def get_by_ids(self, ids, fields=None) -> list:
+    def get_by_ids(self, ids, fields=None) -> list:
         if fields is None:
             return [self._data.get(id, None) for id in ids]
         return [
@@ -39,19 +39,19 @@ class JsonKVStorage(BaseKVStorage):
             for id in ids
         ]

-    async def get_all(self) -> dict[str, str]:
+    def get_all(self) -> dict[str, dict]:
         return self._data

-    async def filter_keys(self, data: list[str]) -> set[str]:
+    def filter_keys(self, data: list[str]) -> set[str]:
         return {s for s in data if s not in self._data}

-    async def upsert(self, data: dict):
+    def upsert(self, data: dict):
         left_data = {k: v for k, v in data.items() if k not in self._data}
         if left_data:
             self._data.update(left_data)
         return left_data

-    async def drop(self):
+    def drop(self):
         if self._data:
             self._data.clear()

@@ -71,26 +71,26 @@ class JsonListStorage(BaseListStorage):
     def data(self):
         return self._data

-    async def all_items(self) -> list:
+    def all_items(self) -> list:
         return self._data

-    async def index_done_callback(self):
+    def index_done_callback(self):
         write_json(self._data, self._file_name)

-    async def get_by_index(self, index: int):
+    def get_by_index(self, index: int):
         if index < 0 or index >= len(self._data):
             return None
         return self._data[index]

-    async def append(self, data):
+    def append(self, data):
         self._data.append(data)

-    async def upsert(self, data: list):
+    def upsert(self, data: list):
         left_data = [d for d in data if d not in self._data]
         self._data.extend(left_data)
         return left_data

-    async def drop(self):
+    def drop(self):
         self._data = []


@@ -101,14 +101,14 @@ class MetaJsonKVStorage(JsonKVStorage):
         self._data = load_json(self._file_name) or {}
         logger.info("Load KV %s with %d data", self.namespace, len(self._data))

-    async def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
+    def get_new_data(self, storage_instance: "JsonKVStorage") -> dict:
         new_data = {}
         for k, v in storage_instance.data.items():
             if k not in self._data:
                 new_data[k] = v
         return new_data

-    async def mark_done(self, storage_instance: "JsonKVStorage"):
-        new_data = await self.get_new_data(storage_instance)
+    def mark_done(self, storage_instance: "JsonKVStorage"):
+        new_data = self.get_new_data(storage_instance)
         if new_data:
             self._data.update(new_data)
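
MetaJsonKVStorage drives the incremental steps seen in graphgen.py: get_new_data returns the entries of another storage that are not yet marked as processed, and mark_done records them. A rough sketch of the pattern under the now-synchronous API (working_dir and namespace values are placeholders):

    docs = JsonKVStorage(working_dir="cache", namespace="full_docs")
    meta = MetaJsonKVStorage(working_dir="cache", namespace="meta")

    pending = meta.get_new_data(docs)   # entries in docs not yet marked done
    for key, doc in pending.items():
        ...                             # process the document
    meta.mark_done(docs)                # remember which keys were handled
    meta.index_done_callback()          # flush meta state to <namespace>.json
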
graphgen/models/storage/networkx_storage.py CHANGED
@@ -91,60 +91,56 @@ class NetworkXStorage(BaseGraphStorage):
         )
         self._graph = preloaded_graph or nx.Graph()

-    async def index_done_callback(self):
+    def index_done_callback(self):
         NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)

-    async def has_node(self, node_id: str) -> bool:
+    def has_node(self, node_id: str) -> bool:
         return self._graph.has_node(node_id)

-    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
+    def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         return self._graph.has_edge(source_node_id, target_node_id)

-    async def get_node(self, node_id: str) -> Union[dict, None]:
+    def get_node(self, node_id: str) -> Union[dict, None]:
         return self._graph.nodes.get(node_id)

-    async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
+    def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
         return list(self._graph.nodes(data=True))

-    async def node_degree(self, node_id: str) -> int:
-        return self._graph.degree(node_id)
+    def node_degree(self, node_id: str) -> int:
+        return int(self._graph.degree[node_id])

-    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        return self._graph.degree(src_id) + self._graph.degree(tgt_id)
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        return int(self._graph.degree[src_id] + self._graph.degree[tgt_id])

-    async def get_edge(
-        self, source_node_id: str, target_node_id: str
-    ) -> Union[dict, None]:
+    def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
         return self._graph.edges.get((source_node_id, target_node_id))

-    async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
+    def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
         return list(self._graph.edges(data=True))

-    async def get_node_edges(
-        self, source_node_id: str
-    ) -> Union[list[tuple[str, str]], None]:
+    def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         if self._graph.has_node(source_node_id):
             return list(self._graph.edges(source_node_id, data=True))
         return None

-    async def get_graph(self) -> nx.Graph:
+    def get_graph(self) -> nx.Graph:
         return self._graph

-    async def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
         self._graph.add_node(node_id, **node_data)

-    async def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, str]):
         if self._graph.has_node(node_id):
             self._graph.nodes[node_id].update(node_data)
         else:
             logger.warning("Node %s not found in the graph for update.", node_id)

-    async def upsert_edge(
+    def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         self._graph.add_edge(source_node_id, target_node_id, **edge_data)

-    async def update_edge(
+    def update_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
     ):
         if self._graph.has_edge(source_node_id, target_node_id):
@@ -156,7 +152,7 @@ class NetworkXStorage(BaseGraphStorage):
             target_node_id,
         )

-    async def delete_node(self, node_id: str):
+    def delete_node(self, node_id: str):
         """
         Delete a node from the graph based on the specified node_id.

@@ -168,7 +164,7 @@ class NetworkXStorage(BaseGraphStorage):
         else:
             logger.warning("Node %s not found in the graph for deletion.", node_id)

-    async def clear(self):
+    def clear(self):
         """
         Clear the graph by removing all nodes and edges.
         """
graphgen/operators/partition/partition_kg.py CHANGED
@@ -34,8 +34,8 @@ async def partition_kg(
         # TODO: before ECE partitioning, we need to:
         # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random
         # 2. pre-tokenize nodes and edges to get the token length
-        edges = await kg_instance.get_all_edges()
-        nodes = await kg_instance.get_all_nodes()
+        edges = kg_instance.get_all_edges()
+        nodes = kg_instance.get_all_nodes()
         await pre_tokenize(kg_instance, tokenizer, edges, nodes)
         partitioner = ECEPartitioner()
     elif method == "leiden":
@@ -105,7 +105,7 @@ async def _attach_by_type(
     image_chunks = [
         data
         for sid in source_ids
-        if "image" in sid.lower() and (data := await chunk_storage.get_by_id(sid))
+        if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid))
     ]
     if image_chunks:
         # The generator expects a dictionary with an 'img_path' key, not a list of captions.
graphgen/operators/partition/pre_tokenize.py CHANGED
@@ -29,9 +29,9 @@ async def pre_tokenize(
             )
         )
         if is_node:
-            await graph_storage.update_node(obj[0], obj[1])
+            graph_storage.update_node(obj[0], obj[1])
         else:
-            await graph_storage.update_edge(obj[0], obj[1], obj[2])
+            graph_storage.update_edge(obj[0], obj[1], obj[2])
         return obj

     new_edges, new_nodes = await asyncio.gather(
@@ -51,5 +51,5 @@
         ),
     )

-    await graph_storage.index_done_callback()
+    graph_storage.index_done_callback()
     return new_edges, new_nodes
graphgen/operators/quiz_and_judge/judge.py CHANGED
@@ -45,16 +45,14 @@ async def judge_statement(  # pylint: disable=too-many-statements
         description = edge_data["description"]

         try:
-            descriptions = await rephrase_storage.get_by_id(description)
+            descriptions = rephrase_storage.get_by_id(description)
             assert descriptions is not None

             judgements = []
             gts = [gt for _, gt in descriptions]
             for description, gt in descriptions:
                 judgement = await trainee_llm_client.generate_topk_per_token(
-                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
-                        statement=description
-                    )
+                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
                 )
                 judgements.append(judgement[0].top_candidates)

@@ -76,10 +74,10 @@ async def judge_statement(  # pylint: disable=too-many-statements
             logger.info("Use default loss 0.1")
             edge_data["loss"] = -math.log(0.1)

-        await graph_storage.update_edge(source_id, target_id, edge_data)
+        graph_storage.update_edge(source_id, target_id, edge_data)
         return source_id, target_id, edge_data

-    edges = await graph_storage.get_all_edges()
+    edges = graph_storage.get_all_edges()

     await run_concurrent(
         _judge_single_relation,
@@ -104,24 +102,20 @@ async def judge_statement(  # pylint: disable=too-many-statements
         description = node_data["description"]

         try:
-            descriptions = await rephrase_storage.get_by_id(description)
+            descriptions = rephrase_storage.get_by_id(description)
             assert descriptions is not None

             judgements = []
             gts = [gt for _, gt in descriptions]
             for description, gt in descriptions:
                 judgement = await trainee_llm_client.generate_topk_per_token(
-                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(
-                        statement=description
-                    )
+                    STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
                 )
                 judgements.append(judgement[0].top_candidates)

             loss = yes_no_loss_entropy(judgements, gts)

-            logger.debug(
-                "Node %s description: %s loss: %s", node_id, description, loss
-            )
+            logger.debug("Node %s description: %s loss: %s", node_id, description, loss)

             node_data["loss"] = loss
         except Exception as e:  # pylint: disable=broad-except
@@ -129,10 +123,10 @@ async def judge_statement(  # pylint: disable=too-many-statements
             logger.error("Use default loss 0.1")
             node_data["loss"] = -math.log(0.1)

-        await graph_storage.update_node(node_id, node_data)
+        graph_storage.update_node(node_id, node_data)
         return node_id, node_data

-    nodes = await graph_storage.get_all_nodes()
+    nodes = graph_storage.get_all_nodes()

     await run_concurrent(
         _judge_single_entity,
graphgen/operators/quiz_and_judge/quiz.py CHANGED
@@ -31,7 +31,7 @@ async def quiz(
         description, template_type, gt = item
         try:
             # if rephrase_storage exists already, directly get it
-            descriptions = await rephrase_storage.get_by_id(description)
+            descriptions = rephrase_storage.get_by_id(description)
             if descriptions:
                 return None

@@ -46,8 +46,8 @@ async def quiz(
             logger.error("Error when quizzing description %s: %s", description, e)
             return None

-    edges = await graph_storage.get_all_edges()
-    nodes = await graph_storage.get_all_nodes()
+    edges = graph_storage.get_all_edges()
+    nodes = graph_storage.get_all_nodes()

     results = defaultdict(list)
     items = []
@@ -88,6 +88,6 @@

     for key, value in results.items():
         results[key] = list(set(value))
-        await rephrase_storage.upsert({key: results[key]})
+        rephrase_storage.upsert({key: results[key]})

     return rephrase_storage
graphgen/operators/storage.py ADDED
@@ -0,0 +1,59 @@
+import os
+from typing import Any
+
+import ray
+
+from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage
+
+
+@ray.remote
+class StorageManager:
+    """
+    Centralized storage for all operators
+
+    Example Usage:
+    ----------
+    # init
+    storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123)
+
+    # visit storage in tasks
+    @ray.remote
+    def some_task(storage_manager):
+        full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs"))
+
+    # visit storage in other actors
+    @ray.remote
+    class SomeOperator:
+        def __init__(self, storage_manager):
+            self.storage_manager = storage_manager
+        def some_method(self):
+            full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs"))
+    """
+
+    def __init__(self, working_dir: str, unique_id: int):
+        self.working_dir = working_dir
+        self.unique_id = unique_id
+
+        # Initialize all storage backends
+        self.storages = {
+            "full_docs": JsonKVStorage(working_dir, namespace="full_docs"),
+            "chunks": JsonKVStorage(working_dir, namespace="chunks"),
+            "graph": NetworkXStorage(working_dir, namespace="graph"),
+            "rephrase": JsonKVStorage(working_dir, namespace="rephrase"),
+            "partition": JsonListStorage(working_dir, namespace="partition"),
+            "search": JsonKVStorage(
+                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
+                namespace="search",
+            ),
+            "extraction": JsonKVStorage(
+                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
+                namespace="extraction",
+            ),
+            "qa": JsonListStorage(
+                os.path.join(working_dir, "data", "graphgen", f"{unique_id}"),
+                namespace="qa",
+            ),
+        }
+
+    def get_storage(self, name: str) -> Any:
+        return self.storages.get(name)
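
StorageManager is a Ray actor, so every worker can talk to one shared set of storage backends instead of opening its own files. A usage sketch following the docstring above (working_dir, unique_id, and the count_docs task are placeholders):

    import ray

    from graphgen.operators.storage import StorageManager

    ray.init(ignore_reinit_error=True)
    manager = StorageManager.remote(working_dir="cache", unique_id=123)


    @ray.remote
    def count_docs(storage_manager):
        # fetch the shared JsonKVStorage handle from the actor
        full_docs = ray.get(storage_manager.get_storage.remote("full_docs"))
        return len(full_docs.all_keys())


    print(ray.get(count_docs.remote(manager)))

Note that objects returned through ray.get are serialized copies, so writes that should be shared need to go back through the actor rather than mutating the local copy.
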
requirements.txt CHANGED
@@ -12,7 +12,7 @@ nltk
 jieba
 plotly
 pandas
-gradio>=5.44.1
+gradio==5.44.1
 kaleido
 pyyaml
 langcodes