github-actions[bot] committed
Commit 31086ae · 1 Parent(s): 10ba08f

Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. app.py +134 -172
  2. graphgen/bases/__init__.py +3 -7
  3. graphgen/bases/base_llm_wrapper.py +0 -6
  4. graphgen/bases/base_operator.py +57 -0
  5. graphgen/bases/base_partitioner.py +22 -27
  6. graphgen/bases/base_reader.py +57 -41
  7. graphgen/bases/base_splitter.py +3 -3
  8. graphgen/bases/base_storage.py +6 -17
  9. graphgen/bases/datatypes.py +44 -0
  10. graphgen/{operators/init → common}/__init__.py +1 -0
  11. graphgen/{operators/init → common}/init_llm.py +125 -29
  12. graphgen/common/init_storage.py +262 -0
  13. graphgen/configs/aggregated_config.yaml +0 -41
  14. graphgen/configs/atomic_config.yaml +0 -31
  15. graphgen/configs/cot_config.yaml +0 -33
  16. graphgen/configs/multi_hop_config.yaml +0 -34
  17. graphgen/configs/schema_guided_extraction_config.yaml +0 -20
  18. graphgen/configs/search_dna_config.yaml +0 -17
  19. graphgen/configs/search_protein_config.yaml +0 -15
  20. graphgen/configs/search_rna_config.yaml +0 -14
  21. graphgen/configs/vqa_config.yaml +0 -32
  22. graphgen/engine.py +191 -106
  23. graphgen/graphgen.py +0 -295
  24. graphgen/models/__init__.py +7 -2
  25. graphgen/models/extractor/schema_guided_extractor.py +3 -5
  26. graphgen/models/generator/vqa_generator.py +2 -2
  27. graphgen/models/llm/local/sglang_wrapper.py +0 -12
  28. graphgen/models/llm/local/vllm_wrapper.py +35 -47
  29. graphgen/models/partitioner/anchor_bfs_partitioner.py +9 -14
  30. graphgen/models/partitioner/bfs_partitioner.py +4 -9
  31. graphgen/models/partitioner/dfs_partitioner.py +5 -9
  32. graphgen/models/partitioner/ece_partitioner.py +19 -24
  33. graphgen/models/partitioner/leiden_partitioner.py +5 -9
  34. graphgen/models/reader/__init__.py +0 -1
  35. graphgen/models/reader/csv_reader.py +14 -11
  36. graphgen/models/reader/json_reader.py +41 -14
  37. graphgen/models/reader/jsonl_reader.py +0 -30
  38. graphgen/models/reader/parquet_reader.py +16 -10
  39. graphgen/models/reader/pdf_reader.py +35 -20
  40. graphgen/models/reader/pickle_reader.py +64 -16
  41. graphgen/models/reader/rdf_reader.py +93 -13
  42. graphgen/models/reader/txt_reader.py +27 -5
  43. graphgen/models/splitter/character_splitter.py +1 -1
  44. graphgen/models/splitter/markdown_splitter.py +2 -2
  45. graphgen/models/splitter/recursive_character_splitter.py +2 -2
  46. graphgen/models/storage/__init__.py +5 -2
  47. graphgen/{configs → models/storage/graph}/__init__.py +0 -0
  48. graphgen/models/storage/graph/kuzu_storage.py +256 -0
  49. graphgen/models/storage/{networkx_storage.py → graph/networkx_storage.py} +17 -20
  50. graphgen/models/storage/kv/__init__.py +0 -0
app.py CHANGED
@@ -5,14 +5,12 @@ import tempfile
 from importlib.resources import files

 import gradio as gr
-import pandas as pd
+import ray
 from dotenv import load_dotenv

-from graphgen.engine import Context, Engine, collect_ops
-from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIClient, Tokenizer
-from graphgen.models.llm.limitter import RPM, TPM
-from graphgen.utils import set_logger
+from graphgen.engine import Engine
+from graphgen.operators import operators
+from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
 from webui.base import WebuiParams
 from webui.i18n import Translate
 from webui.i18n import gettext as _
@@ -22,7 +20,6 @@ from webui.utils import cleanup_workspace, count_tokens, preview_file, setup_wor
 root_dir = files("webui").parent
 sys.path.append(root_dir)

-
 load_dotenv()

 css = """
@@ -34,131 +31,136 @@ css = """
 """


-def init_graph_gen(config: dict, env: dict) -> GraphGen:
-    # Set up working directory
-    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
-    set_logger(log_file, if_stream=True)
-    os.environ.update({k: str(v) for k, v in env.items()})
-
-    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
-    synthesizer_llm_client = OpenAIClient(
-        model=env.get("SYNTHESIZER_MODEL", ""),
-        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
-        api_key=env.get("SYNTHESIZER_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-        tokenizer=tokenizer_instance,
-    )
-    trainee_llm_client = OpenAIClient(
-        model=env.get("TRAINEE_MODEL", ""),
-        base_url=env.get("TRAINEE_BASE_URL", ""),
-        api_key=env.get("TRAINEE_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-        tokenizer=tokenizer_instance,
-    )
-
-    graph_gen = GraphGen(
-        working_dir=working_dir,
-        tokenizer_instance=tokenizer_instance,
-        synthesizer_llm_client=synthesizer_llm_client,
-        trainee_llm_client=trainee_llm_client,
-    )
-
-    return graph_gen
-
-
-# pylint: disable=too-many-statements
-def run_graphgen(params: WebuiParams, progress=gr.Progress()):
-    def sum_tokens(client):
-        return sum(u["total_tokens"] for u in client.token_usage)
-
+def _get_partition_params(params: WebuiParams):
     method = params.partition_method
     if method == "dfs":
-        partition_params = {
+        return {
             "max_units_per_community": params.dfs_max_units,
         }
-    elif method == "bfs":
-        partition_params = {
+    if method == "bfs":
+        return {
             "max_units_per_community": params.bfs_max_units,
         }
-    elif method == "leiden":
-        partition_params = {
+    if method == "leiden":
+        return {
             "max_size": params.leiden_max_size,
             "use_lcc": params.leiden_use_lcc,
             "random_seed": params.leiden_random_seed,
         }
-    else:  # ece
-        partition_params = {
-            "max_units_per_community": params.ece_max_units,
-            "min_units_per_community": params.ece_min_units,
-            "max_tokens_per_community": params.ece_max_tokens,
-            "unit_sampling": params.ece_unit_sampling,
-        }
+    # ece
+    return {
+        "max_units_per_community": params.ece_max_units,
+        "min_units_per_community": params.ece_min_units,
+        "max_tokens_per_community": params.ece_max_tokens,
+        "unit_sampling": params.ece_unit_sampling,
+    }
+
+
+# pylint: disable=too-many-statements
+def run_graphgen(params: WebuiParams, progress=gr.Progress()):
+    # 1. Setup Workspace
+    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
+    driver_logger = set_logger(log_file, "GraphGeb", if_stream=True)
+    CURRENT_LOGGER_VAR.set(driver_logger)
+
+    # 2. Setup Environment Variables for Ray Actors/LLM Init
+    # The refactored code relies on env vars in graphgen/common/init_llm.py
+    os.environ["SYNTHESIZER_BACKEND"] = "openai_api"  # Assuming OpenAI compatible API
+    os.environ["SYNTHESIZER_BASE_URL"] = params.synthesizer_url
+    os.environ["SYNTHESIZER_API_KEY"] = params.api_key
+    os.environ["SYNTHESIZER_MODEL"] = params.synthesizer_model
+    os.environ["RPM"] = str(params.rpm)
+    os.environ["TPM"] = str(params.tpm)
+    os.environ["TOKENIZER_MODEL"] = params.tokenizer
+
+    if params.if_trainee_model:
+        os.environ["TRAINEE_BACKEND"] = "openai_api"
+        os.environ["TRAINEE_BASE_URL"] = params.trainee_url
+        os.environ["TRAINEE_API_KEY"] = params.trainee_api_key
+        os.environ["TRAINEE_MODEL"] = params.trainee_model

-    pipeline = [
+    # 3. Construct Pipeline Configuration (DAG)
+    nodes = [
         {
-            "name": "read",
-            "op_key": "read",
+            "id": "read",
+            "op_name": "read",
+            "type": "source",
+            "dependencies": [],
             "params": {
-                "input_file": params.upload_file,
+                "input_path": [params.upload_file],
             },
         },
         {
-            "name": "chunk",
-            "deps": ["read"],
-            "op_key": "chunk",
+            "id": "chunk",
+            "op_name": "chunk",
+            "type": "map_batch",
+            "dependencies": ["read"],
+            "execution_params": {"replicas": 1},
             "params": {
                 "chunk_size": params.chunk_size,
                 "chunk_overlap": params.chunk_overlap,
             },
         },
         {
-            "name": "build_kg",
-            "deps": ["chunk"],
-            "op_key": "build_kg",
+            "id": "build_kg",
+            "op_name": "build_kg",
+            "type": "map_batch",
+            "dependencies": ["chunk"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
         },
     ]

+    last_node_id = "build_kg"
+
+    # Optional: Quiz and Judge
     if params.if_trainee_model:
-        pipeline.append(
-            {
-                "name": "quiz_and_judge",
-                "deps": ["build_kg"],
-                "op_key": "quiz_and_judge",
-                "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
-            }
-        )
-        pipeline.append(
+        nodes.append(
             {
-                "name": "partition",
-                "deps": ["quiz_and_judge"],
-                "op_key": "partition",
+                "id": "quiz",
+                "op_name": "quiz",
+                "type": "aggregate",  # QuizService uses aggregate in config
+                "dependencies": ["build_kg"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
                 "params": {
-                    "method": params.partition_method,
-                    "method_params": partition_params,
+                    "quiz_samples": params.quiz_samples,
+                    "concurrency_limit": 200,
                 },
             }
         )
-    else:
-        pipeline.append(
+
+        nodes.append(
             {
-                "name": "partition",
-                "deps": ["build_kg"],
-                "op_key": "partition",
-                "params": {
-                    "method": params.partition_method,
-                    "method_params": partition_params,
-                },
+                "id": "judge",
+                "op_name": "judge",
+                "type": "map_batch",
+                "dependencies": ["quiz"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
             }
         )
-    pipeline.append(
+        last_node_id = "judge"
+
+    # Node: Partition
+    nodes.append(
+        {
+            "id": "partition",
+            "op_name": "partition",
+            "type": "aggregate",  # PartitionService uses aggregate
+            "dependencies": [last_node_id],
+            "params": {
+                "method": params.partition_method,
+                "method_params": _get_partition_params(params),
+            },
+        }
+    )
+
+    # Node: Generate
+    nodes.append(
         {
-            "name": "generate",
-            "deps": ["partition"],
-            "op_key": "generate",
+            "id": "generate",
+            "op_name": "generate",
+            "type": "map_batch",
+            "dependencies": ["partition"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
             "params": {
                 "method": params.mode,
                 "data_format": params.data_format,
@@ -166,88 +168,50 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
             }
         )

-    config = {
-        "if_trainee_model": params.if_trainee_model,
-        "read": {"input_file": params.upload_file},
-        "pipeline": pipeline,
-    }
+    config = {"global_params": {"working_dir": working_dir}, "nodes": nodes}

-    env = {
-        "TOKENIZER_MODEL": params.tokenizer,
-        "SYNTHESIZER_BASE_URL": params.synthesizer_url,
-        "SYNTHESIZER_MODEL": params.synthesizer_model,
-        "TRAINEE_BASE_URL": params.trainee_url,
-        "TRAINEE_MODEL": params.trainee_model,
-        "SYNTHESIZER_API_KEY": params.api_key,
-        "TRAINEE_API_KEY": params.trainee_api_key,
-        "RPM": params.rpm,
-        "TPM": params.tpm,
-    }
-
-    # Test API connection
-    test_api_connection(
-        env["SYNTHESIZER_BASE_URL"],
-        env["SYNTHESIZER_API_KEY"],
-        env["SYNTHESIZER_MODEL"],
-    )
-    if config["if_trainee_model"]:
-        test_api_connection(
-            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
-        )
-
-    # Initialize GraphGen
-    graph_gen = init_graph_gen(config, env)
-    graph_gen.clear()
-    graph_gen.progress_bar = progress
-
-    try:
-        ctx = Context(config=config, graph_gen=graph_gen)
-        ops = collect_ops(config, graph_gen)
-        Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx)
-
-        # Save output
-        output_data = graph_gen.qa_storage.data
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
-        ) as tmpfile:
-            json.dump(output_data, tmpfile, ensure_ascii=False)
-            output_file = tmpfile.name
-
-        synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client)
-        trainee_tokens = (
-            sum_tokens(graph_gen.trainee_llm_client)
-            if config["if_trainee_model"]
-            else 0
-        )
-        total_tokens = synthesizer_tokens + trainee_tokens
-
-        data_frame = params.token_counter
-        try:
-            _update_data = [
-                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
-            ]
-            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
-            data_frame = new_df
-
-        except Exception as e:
-            raise gr.Error(f"DataFrame operation error: {str(e)}")
-
-        return output_file, gr.DataFrame(
-            label="Token Stats",
-            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
-            datatype="str",
-            interactive=False,
-            value=data_frame,
-            visible=True,
-            wrap=True,
-        )
+    try:
+        # 4. Initialize and Run Engine
+        # Initialize Ray if not already running (Engine handles this mostly, but good for safety)
+        if not ray.is_initialized():
+            ray.init(ignore_reinit_error=True, log_to_driver=True)
+
+        engine = Engine(config, operators)
+
+        # Start with an empty dataset to kick off the pipeline
+        ds = ray.data.from_items([])
+
+        # Execute pipeline
+        results = engine.execute(ds)
+
+        # 5. Process Output
+        # Extract the result from the 'generate' node
+        if "generate" in results:
+            result_ds = results["generate"]
+
+            # Create a temporary file to save the output
+            with tempfile.NamedTemporaryFile(
+                mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
+            ) as tmpfile:
+                # Iterate over rows and write to file
+                for row in result_ds.iter_rows():
+                    json.dump(row, tmpfile, ensure_ascii=False)
+                    tmpfile.write("\n")
+                output_file = tmpfile.name
+        else:
+            raise gr.Error("Generation step failed to produce output.")
+
+        # Note: Dynamic token counting from distributed actors is not directly available
+        # via client properties in the new architecture. We return the estimated stats from input.
+
+        return output_file, params.token_counter

     except Exception as e:  # pylint: disable=broad-except
         raise gr.Error(f"Error occurred: {str(e)}")

     finally:
         # Clean up workspace
-        cleanup_workspace(graph_gen.working_dir)
+        cleanup_workspace(working_dir)  # Optional: keep for debugging or enable


 with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
@@ -267,7 +231,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
             ("简体中文", "zh"),
         ],
         value="en",
-        # label=_("Language"),
         render=False,
         container=False,
         elem_classes=["center-row"],
@@ -295,7 +258,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
         os.path.join(root_dir, "webui", "translation.json"),
         lang_btn,
         placeholder_langs=["en", "zh"],
-        persistant=False,  # True to save the language setting in the browser. Requires gradio >= 5.6.0
+        persistant=False,
     ):
         lang_btn.render()

@@ -701,7 +664,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
         outputs=[output, token_counter],
     )

-
 if __name__ == "__main__":
     demo.queue(api_open=False, default_concurrency_limit=2)
     demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
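For orientation, the refactored entry point can be exercised outside Gradio. The sketch below is illustrative only: it assumes Engine(config, operators) and engine.execute(ds) behave as shown in the diff above, and it reuses the example input path from the deleted configs.

import ray

from graphgen.engine import Engine
from graphgen.operators import operators

# Minimal two-node DAG in the new config schema (id / op_name / type /
# dependencies / params), mirroring what run_graphgen assembles above.
config = {
    "global_params": {"working_dir": "cache"},
    "nodes": [
        {
            "id": "read",
            "op_name": "read",
            "type": "source",
            "dependencies": [],
            "params": {"input_path": ["resources/input_examples/jsonl_demo.jsonl"]},
        },
        {
            "id": "chunk",
            "op_name": "chunk",
            "type": "map_batch",
            "dependencies": ["read"],
            "params": {"chunk_size": 1024, "chunk_overlap": 100},
        },
    ],
}

if not ray.is_initialized():
    ray.init(ignore_reinit_error=True)

engine = Engine(config, operators)
results = engine.execute(ray.data.from_items([]))  # maps node id -> Ray Dataset
print(results["chunk"].count())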
graphgen/bases/__init__.py CHANGED
@@ -2,15 +2,11 @@ from .base_extractor import BaseExtractor
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_wrapper import BaseLLMWrapper
+from .base_operator import BaseOperator
 from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_searcher import BaseSearcher
 from .base_splitter import BaseSplitter
-from .base_storage import (
-    BaseGraphStorage,
-    BaseKVStorage,
-    BaseListStorage,
-    StorageNameSpace,
-)
+from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
 from .base_tokenizer import BaseTokenizer
-from .datatypes import Chunk, QAPair, Token
+from .datatypes import Chunk, Config, Node, QAPair, Token
graphgen/bases/base_llm_wrapper.py CHANGED
@@ -72,9 +72,3 @@ class BaseLLMWrapper(abc.ABC):

         filtered = filtered.strip()
         return filtered if filtered else text.strip()
-
-    def shutdown(self) -> None:
-        """Shutdown the LLM engine if applicable."""
-
-    def restart(self) -> None:
-        """Reinitialize the LLM engine if applicable."""
graphgen/bases/base_operator.py ADDED
@@ -0,0 +1,57 @@
+import inspect
+import os
+from abc import ABC, abstractmethod
+from typing import Iterable, Union
+
+import pandas as pd
+import ray
+
+from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
+
+
+class BaseOperator(ABC):
+    def __init__(self, working_dir: str = "cache", op_name: str = None):
+        log_dir = os.path.join(working_dir, "logs")
+        self.op_name = op_name or self.__class__.__name__
+
+        try:
+            ctx = ray.get_runtime_context()
+            worker_id = ctx.get_actor_id() or ctx.get_worker_id()
+            worker_id_short = worker_id[-6:] if worker_id else "driver"
+        except Exception as e:
+            print(
+                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
+                e,
+            )
+            worker_id_short = "local"
+
+        # e.g. cache/logs/ChunkService_a1b2c3.log
+        log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log")
+
+        self.logger = set_logger(
+            log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True
+        )
+
+        self.logger.info(
+            "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short
+        )
+
+    def __call__(
+        self, batch: pd.DataFrame
+    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
+        logger_token = CURRENT_LOGGER_VAR.set(self.logger)
+        try:
+            result = self.process(batch)
+            if inspect.isgenerator(result):
+                yield from result
+            else:
+                yield result
+        finally:
+            CURRENT_LOGGER_VAR.reset(logger_token)
+
+    @abstractmethod
+    def process(self, batch):
+        raise NotImplementedError("Subclasses must implement the process method.")
+
+    def get_logger(self):
+        return self.logger
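As a sketch of how a concrete operator plugs into this base class, a subclass only implements process; __call__ handles per-worker logger scoping and yields pandas batches, which matches Ray Data's callable-class UDF convention. UpperCaseService below is hypothetical, not a repo operator.

import pandas as pd

from graphgen.bases import BaseOperator


class UpperCaseService(BaseOperator):
    """Toy operator: upper-case the 'content' column of each pandas batch."""

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        out = batch.copy()
        out["content"] = out["content"].str.upper()
        return out


# With Ray Data this would typically be wired in via something like
#   ds.map_batches(UpperCaseService, batch_format="pandas",
#                  fn_constructor_kwargs={"working_dir": "cache"}, concurrency=1)
# so each replica constructs one instance and __call__ runs once per batch.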
graphgen/bases/base_partitioner.py CHANGED
@@ -7,7 +7,7 @@ from graphgen.bases.datatypes import Community

 class BasePartitioner(ABC):
     @abstractmethod
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         **kwargs: Any,
@@ -20,39 +20,34 @@ class BasePartitioner(ABC):
         """

     @staticmethod
-    async def community2batch(
-        communities: List[Community], g: BaseGraphStorage
-    ) -> list[
-        tuple[
-            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
-        ]
+    def community2batch(
+        comm: Community, g: BaseGraphStorage
+    ) -> tuple[
+        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
     ]:
         """
         Convert communities to batches of nodes and edges.
-        :param communities
+        :param comm: Community
         :param g: Graph storage instance
         :return: List of batches, each batch is a tuple of (nodes, edges)
         """
-        batches = []
-        for comm in communities:
-            nodes = comm.nodes
-            edges = comm.edges
-            nodes_data = []
-            for node in nodes:
-                node_data = g.get_node(node)
-                if node_data:
-                    nodes_data.append((node, node_data))
-            edges_data = []
-            for u, v in edges:
-                edge_data = g.get_edge(u, v)
+        nodes = comm.nodes
+        edges = comm.edges
+        nodes_data = []
+        for node in nodes:
+            node_data = g.get_node(node)
+            if node_data:
+                nodes_data.append((node, node_data))
+        edges_data = []
+        for u, v in edges:
+            edge_data = g.get_edge(u, v)
+            if edge_data:
+                edges_data.append((u, v, edge_data))
+            else:
+                edge_data = g.get_edge(v, u)
                 if edge_data:
-                    edges_data.append((u, v, edge_data))
-                else:
-                    edge_data = g.get_edge(v, u)
-                    if edge_data:
-                        edges_data.append((v, u, edge_data))
-            batches.append((nodes_data, edges_data))
-        return batches
+                    edges_data.append((v, u, edge_data))
+        return nodes_data, edges_data

     @staticmethod
     def _build_adjacency_list(
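A hedged usage sketch of the new per-community contract: partition is now synchronous and community2batch converts one Community at a time (my_partitioner and g are assumed to be any BasePartitioner subclass and a populated BaseGraphStorage).

from graphgen.bases import BasePartitioner

communities = my_partitioner.partition(g, max_units_per_community=20)
batches = [BasePartitioner.community2batch(comm, g) for comm in communities]
for nodes_data, edges_data in batches:
    print(len(nodes_data), len(edges_data))  # one community's nodes and edges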
graphgen/bases/base_reader.py CHANGED
@@ -1,8 +1,10 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union

+import pandas as pd
 import requests
+from ray.data import Dataset


 class BaseReader(ABC):
@@ -10,56 +12,70 @@ class BaseReader(ABC):
     Abstract base class for reading and processing data.
     """

-    def __init__(self, text_column: str = "content"):
+    def __init__(self, text_column: str = "content", modalities: list = None):
         self.text_column = text_column
+        self.modalities = modalities if modalities is not None else ["text"]

     @abstractmethod
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
         """
         Read data from the specified file path.

-        :param file_path: Path to the input file.
-        :return: List of dictionaries containing the data.
+        :param input_path: Path to the input file or list of file paths.
+        :return: Ray Dataset containing the read data.
         """

-    @staticmethod
-    def filter(data: List[dict]) -> List[dict]:
+    def _should_keep_item(self, item: Dict[str, Any]) -> bool:
+        """
+        Determine whether to keep the given item based on the text column.
+
+        :param item: Dictionary representing a data entry.
+        :return: True if the item should be kept, False otherwise.
         """
-        Filter out entries with empty or missing text in the specified column.
+        item_type = item.get("type")
+        assert item_type in [
+            "text",
+            "image",
+            "table",
+            "equation",
+            "protein",
+        ], f"Unsupported item type: {item_type}"
+        if item_type == "text":
+            content = item.get(self.text_column, "").strip()
+            return bool(content)
+        return True

-        :param data: List of dictionaries containing the data.
-        :return: Filtered list of dictionaries.
+    def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
+        """
+        Validate data format.
         """
+        if "type" not in batch.columns:
+            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")

-        def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
-            """
-            Check if an image exists at the given local path or URL.
-            :param path_or_url: Local file path or remote URL of the image.
-            :param timeout: Timeout for remote URL requests in seconds.
-            :return: True if the image exists, False otherwise.
-            """
-            if not path_or_url:
-                return False
-            if not path_or_url.startswith(("http://", "https://", "ftp://")):
-                path = path_or_url.replace("file://", "", 1)
-                path = os.path.abspath(path)
-                return os.path.isfile(path)
-            try:
-                resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
-                return resp.status_code == 200
-            except requests.RequestException:
-                return False
+        if "text" in batch["type"].values:
+            if self.text_column not in batch.columns:
+                raise ValueError(
+                    f"Missing '{self.text_column}' column for text documents"
+                )

-        filtered_data = []
-        for item in data:
-            if item.get("type") == "text":
-                content = item.get("content", "").strip()
-                if content:
-                    filtered_data.append(item)
-            elif item.get("type") in ("image", "table", "equation"):
-                img_path = item.get("img_path")
-                if _image_exists(img_path):
-                    filtered_data.append(item)
-            else:
-                filtered_data.append(item)
-        return filtered_data
+        return batch
+
+    @staticmethod
+    def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+        """
+        Check if an image exists at the given local path or URL.
+        :param path_or_url: Local file path or remote URL of the image.
+        :param timeout: Timeout for remote URL requests in seconds.
+        :return: True if the image exists, False otherwise.
+        """
+        if not path_or_url:
+            return False
+        if not path_or_url.startswith(("http://", "https://", "ftp://")):
+            path = path_or_url.replace("file://", "", 1)
+            path = os.path.abspath(path)
+            return os.path.isfile(path)
+        try:
+            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+            return resp.status_code == 200
+        except requests.RequestException:
+            return False
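A sketch of what a concrete reader looks like under the new contract, where read() returns a Ray Dataset instead of a list of dicts. SimpleJSONLReader is illustrative, not the repo's implementation.

from typing import List, Union

import ray
from ray.data import Dataset

from graphgen.bases import BaseReader


class SimpleJSONLReader(BaseReader):
    def read(self, input_path: Union[str, List[str]]) -> Dataset:
        # ray.data.read_json handles both .json and line-delimited .jsonl input
        ds = ray.data.read_json(input_path)
        # Reuse the base-class helpers: validate per batch, then filter per row
        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
        return ds.filter(self._should_keep_item)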
graphgen/bases/base_splitter.py CHANGED
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import Callable, Iterable, List, Literal, Optional, Union

 from graphgen.bases.datatypes import Chunk
-from graphgen.utils import logger
+from graphgen.utils.log import logger


 class BaseSplitter(ABC):
@@ -33,7 +33,7 @@ class BaseSplitter(ABC):
         """
         Split the input text into smaller chunks.

-        :param text: The input text to be split.
+        :param text: The input text to be chunk.
         :return: A list of text chunks.
         """

@@ -111,7 +111,7 @@ class BaseSplitter(ABC):
     def _split_text_with_regex(
         text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
     ) -> List[str]:
-        # Now that we have the separator, split the text
+        # Now that we have the separator, chunk the text
         if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
graphgen/bases/base_storage.py CHANGED
@@ -16,23 +16,6 @@ class StorageNameSpace:
         """commit the storage operations after querying"""


-class BaseListStorage(Generic[T], StorageNameSpace):
-    def all_items(self) -> list[T]:
-        raise NotImplementedError
-
-    def get_by_index(self, index: int) -> Union[T, None]:
-        raise NotImplementedError
-
-    def append(self, data: T):
-        raise NotImplementedError
-
-    def upsert(self, data: list[T]):
-        raise NotImplementedError
-
-    def drop(self):
-        raise NotImplementedError
-
-
 class BaseKVStorage(Generic[T], StorageNameSpace):
     def all_keys(self) -> list[str]:
         raise NotImplementedError
@@ -58,6 +41,9 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
     def drop(self):
         raise NotImplementedError

+    def reload(self):
+        raise NotImplementedError
+

 class BaseGraphStorage(StorageNameSpace):
     def has_node(self, node_id: str) -> bool:
@@ -105,3 +91,6 @@ class BaseGraphStorage(StorageNameSpace):

     def delete_node(self, node_id: str):
         raise NotImplementedError
+
+    def reload(self):
+        raise NotImplementedError
graphgen/bases/datatypes.py CHANGED
@@ -2,6 +2,8 @@ import math
 from dataclasses import dataclass, field
 from typing import List, Union

+from pydantic import BaseModel, Field, field_validator
+

 @dataclass
 class Chunk:
@@ -48,3 +50,45 @@ class Community:
     nodes: List[str] = field(default_factory=list)
     edges: List[tuple] = field(default_factory=list)
     metadata: dict = field(default_factory=dict)
+
+
+class Node(BaseModel):
+    id: str = Field(..., description="unique node id")
+    op_name: str = Field(..., description="operator name")
+    type: str = Field(
+        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
+    )
+    params: dict = Field(default_factory=dict, description="operator parameters")
+    dependencies: List[str] = Field(
+        default_factory=list, description="list of dependent node ids"
+    )
+    execution_params: dict = Field(
+        default_factory=dict, description="execution parameters like replicas, batch_size"
+    )
+
+    @classmethod
+    @field_validator("type")
+    def validate_type(cls, v: str) -> str:
+        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
+        if v not in valid_types:
+            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
+        return v
+
+
+class Config(BaseModel):
+    global_params: dict = Field(
+        default_factory=dict, description="global context for the computation graph"
+    )
+
+    nodes: List[Node] = Field(
+        ..., min_length=1, description="list of nodes in the computation graph"
+    )
+
+    @classmethod
+    @field_validator("nodes")
+    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
+        ids = [node.id for node in v]
+        if len(ids) != len(set(ids)):
+            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
+            raise ValueError(f"Duplicate node ids found: {duplicates}")
+        return v
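A quick illustration of how these models are meant to gate a pipeline config. This is a sketch under the field definitions above; whether the stacked @classmethod/@field_validator ordering registers the validators depends on the installed pydantic version.

from graphgen.bases.datatypes import Config

# A well-formed two-node graph; Config requires at least one node, and the
# validators above are intended to reject unknown types and duplicate ids.
cfg = Config(
    global_params={"working_dir": "cache"},
    nodes=[
        {"id": "read", "op_name": "read", "type": "map_batch"},
        {"id": "chunk", "op_name": "chunk", "type": "map_batch", "dependencies": ["read"]},
    ],
)
print([n.id for n in cfg.nodes])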
graphgen/{operators/init → common}/__init__.py RENAMED
@@ -1 +1,2 @@
 from .init_llm import init_llm
+from .init_storage import init_storage
graphgen/{operators/init → common}/init_llm.py RENAMED
@@ -1,56 +1,152 @@
 import os
 from typing import Any, Dict, Optional

+import ray
+
 from graphgen.bases import BaseLLMWrapper
+from graphgen.common.init_storage import get_actor_handle
 from graphgen.models import Tokenizer


-class LLMFactory:
+class LLMServiceActor:
     """
-    A factory class to create LLM wrapper instances based on the specified backend.
-    Supported backends include:
-    - http_api: HTTPClient
-    - openai_api: OpenAIClient
-    - ollama_api: OllamaClient
-    - huggingface: HuggingFaceWrapper
-    - sglang: SGLangWrapper
+    A Ray actor class to wrap LLM wrapper instances for distributed usage.
     """

-    @staticmethod
-    def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
-        # add tokenizer
-        tokenizer: Tokenizer = Tokenizer(
-            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
-        )
+    def __init__(self, backend: str, config: Dict[str, Any]):
+        self.backend = backend
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        tokenizer = Tokenizer(model_name=tokenizer_model)
         config["tokenizer"] = tokenizer
+
         if backend == "http_api":
             from graphgen.models.llm.api.http_client import HTTPClient

-            return HTTPClient(**config)
-        if backend in ("openai_api", "azure_openai_api"):
+            self.llm_instance = HTTPClient(**config)
+        elif backend in ("openai_api", "azure_openai_api"):
             from graphgen.models.llm.api.openai_client import OpenAIClient
+
             # pass in concrete backend to the OpenAIClient so that internally we can distinguish
             # between OpenAI and Azure OpenAI
-            return OpenAIClient(**config, backend=backend)
-        if backend == "ollama_api":
+            self.llm_instance = OpenAIClient(**config, backend=backend)
+        elif backend == "ollama_api":
             from graphgen.models.llm.api.ollama_client import OllamaClient

-            return OllamaClient(**config)
-        if backend == "huggingface":
+            self.llm_instance = OllamaClient(**config)
+        elif backend == "huggingface":
             from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

-            return HuggingFaceWrapper(**config)
-        if backend == "sglang":
+            self.llm_instance = HuggingFaceWrapper(**config)
+        elif backend == "sglang":
             from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

-            return SGLangWrapper(**config)
+            self.llm_instance = SGLangWrapper(**config)
+
+        elif backend == "vllm":
+            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+            self.llm_instance = VLLMWrapper(**config)
+        else:
+            raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        return await self.llm_instance.generate_answer(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_topk_per_token(text, history, **extra)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_inputs_prob(text, history, **extra)
+
+    def ready(self) -> bool:
+        """A simple method to check if the actor is ready."""
+        return True
+
+
+class LLMServiceProxy(BaseLLMWrapper):
+    """
+    A proxy class to interact with the LLMServiceActor for distributed LLM operations.
+    """
+
+    def __init__(self, actor_name: str):
+        super().__init__()
+        self.actor_handle = get_actor_handle(actor_name)
+        self._create_local_tokenizer()
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        object_ref = self.actor_handle.generate_answer.remote(text, history, **extra)
+        return await object_ref
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_topk_per_token.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_inputs_prob.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    def _create_local_tokenizer(self):
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer = Tokenizer(model_name=tokenizer_model)
+
+
+class LLMFactory:
+    """
+    A factory class to create LLM wrapper instances based on the specified backend.
+    Supported backends include:
+    - http_api: HTTPClient
+    - openai_api: OpenAIClient
+    - ollama_api: OllamaClient
+    - huggingface: HuggingFaceWrapper
+    - sglang: SGLangWrapper
+    """
+
+    @staticmethod
+    def create_llm(
+        model_type: str, backend: str, config: Dict[str, Any]
+    ) -> BaseLLMWrapper:
+        if not config:
+            raise ValueError(
+                f"No configuration provided for LLM {model_type} with backend {backend}."
+            )
+
+        actor_name = f"Actor_LLM_{model_type}"
+        try:
+            ray.get_actor(actor_name)
+        except ValueError:
+            print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
+            num_gpus = int(config.pop("num_gpus", 0))
+            actor = (
+                ray.remote(LLMServiceActor)
+                .options(
+                    name=actor_name,
+                    num_gpus=num_gpus,
+                    lifetime="detached",
+                    get_if_exists=True,
+                )
+                .remote(backend, config)
+            )

-        # if backend == "vllm":
-        #     from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
-        #
-        #     return VLLMWrapper(**config)
+            # wait for actor to be ready
+            ray.get(actor.ready.remote())

-        raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+        return LLMServiceProxy(actor_name)


 def _load_env_group(prefix: str) -> Dict[str, Any]:
@@ -77,5 +173,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
     if not config:
         return None
     backend = config.pop("backend")
-    llm_wrapper = LLMFactory.create_llm_wrapper(backend, config)
+    llm_wrapper = LLMFactory.create_llm(model_type, backend, config)
     return llm_wrapper
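The core trick here, isolated: a named, detached Ray actor is created at most once and then looked up by name from any worker. A minimal self-contained sketch, with EchoActor standing in for LLMServiceActor:

import ray


class EchoActor:
    def ready(self) -> bool:
        return True

    def echo(self, text: str) -> str:
        return text


ray.init(ignore_reinit_error=True)

try:
    actor = ray.get_actor("Actor_LLM_demo")  # reuse if it already exists
except ValueError:
    actor = (
        ray.remote(EchoActor)
        .options(name="Actor_LLM_demo", lifetime="detached", get_if_exists=True)
        .remote()
    )
    ray.get(actor.ready.remote())  # block until construction finishes

print(ray.get(actor.echo.remote("hello")))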
graphgen/common/init_storage.py ADDED
@@ -0,0 +1,262 @@
+from typing import Any, Dict, Union
+
+import ray
+
+from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
+
+
+class KVStorageActor:
+    def __init__(self, backend: str, working_dir: str, namespace: str):
+        if backend == "json_kv":
+            from graphgen.models import JsonKVStorage
+
+            self.kv = JsonKVStorage(working_dir, namespace)
+        elif backend == "rocksdb":
+            from graphgen.models import RocksDBKVStorage
+
+            self.kv = RocksDBKVStorage(working_dir, namespace)
+        else:
+            raise ValueError(f"Unknown KV backend: {backend}")
+
+    def data(self) -> Dict[str, Dict]:
+        return self.kv.data
+
+    def all_keys(self) -> list[str]:
+        return self.kv.all_keys()
+
+    def index_done_callback(self):
+        return self.kv.index_done_callback()
+
+    def get_by_id(self, id: str) -> Dict:
+        return self.kv.get_by_id(id)
+
+    def get_by_ids(self, ids: list[str], fields=None) -> list:
+        return self.kv.get_by_ids(ids, fields)
+
+    def get_all(self) -> Dict[str, Dict]:
+        return self.kv.get_all()
+
+    def filter_keys(self, data: list[str]) -> set[str]:
+        return self.kv.filter_keys(data)
+
+    def upsert(self, data: dict) -> dict:
+        return self.kv.upsert(data)
+
+    def drop(self):
+        return self.kv.drop()
+
+    def reload(self):
+        return self.kv.reload()
+
+
+class GraphStorageActor:
+    def __init__(self, backend: str, working_dir: str, namespace: str):
+        if backend == "networkx":
+            from graphgen.models import NetworkXStorage
+
+            self.graph = NetworkXStorage(working_dir, namespace)
+        elif backend == "kuzu":
+            from graphgen.models import KuzuStorage
+
+            self.graph = KuzuStorage(working_dir, namespace)
+        else:
+            raise ValueError(f"Unknown Graph backend: {backend}")
+
+    def index_done_callback(self):
+        return self.graph.index_done_callback()
+
+    def has_node(self, node_id: str) -> bool:
+        return self.graph.has_node(node_id)
+
+    def has_edge(self, source_node_id: str, target_node_id: str):
+        return self.graph.has_edge(source_node_id, target_node_id)
+
+    def node_degree(self, node_id: str) -> int:
+        return self.graph.node_degree(node_id)
+
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        return self.graph.edge_degree(src_id, tgt_id)
+
+    def get_node(self, node_id: str) -> Any:
+        return self.graph.get_node(node_id)
+
+    def update_node(self, node_id: str, node_data: dict[str, str]):
+        return self.graph.update_node(node_id, node_data)
+
+    def get_all_nodes(self) -> Any:
+        return self.graph.get_all_nodes()
+
+    def get_edge(self, source_node_id: str, target_node_id: str):
+        return self.graph.get_edge(source_node_id, target_node_id)
+
+    def update_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return self.graph.update_edge(source_node_id, target_node_id, edge_data)
+
+    def get_all_edges(self) -> Any:
+        return self.graph.get_all_edges()
+
+    def get_node_edges(self, source_node_id: str) -> Any:
+        return self.graph.get_node_edges(source_node_id)
+
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        return self.graph.upsert_node(node_id, node_data)
+
+    def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return self.graph.upsert_edge(source_node_id, target_node_id, edge_data)
+
+    def delete_node(self, node_id: str):
+        return self.graph.delete_node(node_id)
+
+    def reload(self):
+        return self.graph.reload()
+
+
+def get_actor_handle(name: str):
+    try:
+        return ray.get_actor(name)
+    except ValueError as exc:
+        raise RuntimeError(
+            f"Actor {name} not found. Make sure it is created before accessing."
+        ) from exc
+
+
+class RemoteKVStorageProxy(BaseKVStorage):
+    def __init__(self, namespace: str):
+        super().__init__()
+        self.namespace = namespace
+        self.actor_name = f"Actor_KV_{namespace}"
+        self.actor = get_actor_handle(self.actor_name)
+
+    def data(self) -> Dict[str, Any]:
+        return ray.get(self.actor.data.remote())
+
+    def all_keys(self) -> list[str]:
+        return ray.get(self.actor.all_keys.remote())
+
+    def index_done_callback(self):
+        return ray.get(self.actor.index_done_callback.remote())
+
+    def get_by_id(self, id: str) -> Union[Any, None]:
+        return ray.get(self.actor.get_by_id.remote(id))
+
+    def get_by_ids(self, ids: list[str], fields=None) -> list[Any]:
+        return ray.get(self.actor.get_by_ids.remote(ids, fields))
+
+    def get_all(self) -> Dict[str, Any]:
+        return ray.get(self.actor.get_all.remote())
+
+    def filter_keys(self, data: list[str]) -> set[str]:
+        return ray.get(self.actor.filter_keys.remote(data))
+
+    def upsert(self, data: Dict[str, Any]):
+        return ray.get(self.actor.upsert.remote(data))
+
+    def drop(self):
+        return ray.get(self.actor.drop.remote())
+
+    def reload(self):
+        return ray.get(self.actor.reload.remote())
+
+
+class RemoteGraphStorageProxy(BaseGraphStorage):
+    def __init__(self, namespace: str):
+        super().__init__()
+        self.namespace = namespace
+        self.actor_name = f"Actor_Graph_{namespace}"
+        self.actor = get_actor_handle(self.actor_name)
+
+    def index_done_callback(self):
+        return ray.get(self.actor.index_done_callback.remote())
+
+    def has_node(self, node_id: str) -> bool:
+        return ray.get(self.actor.has_node.remote(node_id))
+
+    def has_edge(self, source_node_id: str, target_node_id: str):
+        return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))
+
+    def node_degree(self, node_id: str) -> int:
+        return ray.get(self.actor.node_degree.remote(node_id))
+
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))
+
+    def get_node(self, node_id: str) -> Any:
+        return ray.get(self.actor.get_node.remote(node_id))
+
+    def update_node(self, node_id: str, node_data: dict[str, str]):
+        return ray.get(self.actor.update_node.remote(node_id, node_data))
+
+    def get_all_nodes(self) -> Any:
+        return ray.get(self.actor.get_all_nodes.remote())
+
+    def get_edge(self, source_node_id: str, target_node_id: str):
+        return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))
+
+    def update_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return ray.get(
+            self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
+        )
+
+    def get_all_edges(self) -> Any:
+        return ray.get(self.actor.get_all_edges.remote())
+
+    def get_node_edges(self, source_node_id: str) -> Any:
+        return ray.get(self.actor.get_node_edges.remote(source_node_id))
+
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        return ray.get(self.actor.upsert_node.remote(node_id, node_data))
+
+    def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return ray.get(
+            self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
+        )
+
+    def delete_node(self, node_id: str):
+        return ray.get(self.actor.delete_node.remote(node_id))
+
+    def reload(self):
+        return ray.get(self.actor.reload.remote())
+
+
+class StorageFactory:
+    """
+    Factory class to create storage instances based on backend.
+    """
+
+    @staticmethod
+    def create_storage(backend: str, working_dir: str, namespace: str):
+        if backend in ["json_kv", "rocksdb"]:
+            actor_name = f"Actor_KV_{namespace}"
+            try:
+                ray.get_actor(actor_name)
+            except ValueError:
+                ray.remote(KVStorageActor).options(
+                    name=actor_name,
+                    lifetime="detached",
+                    get_if_exists=True,
+                ).remote(backend, working_dir, namespace)
+            return RemoteKVStorageProxy(namespace)
+        if backend in ["networkx", "kuzu"]:
+            actor_name = f"Actor_Graph_{namespace}"
+            try:
+                ray.get_actor(actor_name)
+            except ValueError:
+                ray.remote(GraphStorageActor).options(
+                    name=actor_name,
+                    lifetime="detached",
+                    get_if_exists=True,
+                ).remote(backend, working_dir, namespace)
+            return RemoteGraphStorageProxy(namespace)
+        raise ValueError(f"Unknown storage backend: {backend}")
+
+
+def init_storage(backend: str, working_dir: str, namespace: str):
+    return StorageFactory.create_storage(backend, working_dir, namespace)
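Hypothetical driver-side usage: because the proxies resolve named actors, any process in the same Ray cluster that asks for the same namespace shares one storage instance.

import ray

from graphgen.common import init_storage

ray.init(ignore_reinit_error=True)

kv = init_storage("json_kv", "cache", "chunks")  # -> RemoteKVStorageProxy
kv.upsert({"doc-1": {"content": "hello"}})
print(kv.get_by_id("doc-1"))

graph = init_storage("networkx", "cache", "graph")  # -> RemoteGraphStorageProxy
graph.upsert_node("n1", {"entity_type": "person"})
print(graph.has_node("n1"))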
graphgen/configs/aggregated_config.yaml DELETED
@@ -1,41 +0,0 @@
-pipeline:
-  - name: read_step # step name is unique in the pipeline, and can be referenced by other steps
-    op_key: read
-    params:
-      input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step] # build_kg_step depends on chunk_step
-
-  - name: quiz_and_judge_step
-    op_key: quiz_and_judge
-    deps: [build_kg_step] # quiz_and_judge depends on build_kg_step
-    params:
-      quiz_samples: 2 # number of quiz samples to generate
-      re_judge: false # whether to re-judge the existing quiz samples
-
-  - name: partition_step
-    op_key: partition
-    deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step
-    params:
-      method: ece # ece is a custom partition method based on comprehension loss
-      method_params:
-        max_units_per_community: 20 # max nodes and edges per community
-        min_units_per_community: 5 # min nodes and edges per community
-        max_tokens_per_community: 10240 # max tokens per community
-        unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step] # generate_step depends on partition_step
-    params:
-      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/configs/atomic_config.yaml DELETED
@@ -1,31 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step] # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg] # partition_step depends on build_kg
-    params:
-      method: dfs # partition method, support: dfs, bfs, ece, leiden
-      method_params:
-        max_units_per_community: 1 # atomic partition, one node or edge per community
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step] # generate_step depends on partition_step
-    params:
-      method: atomic # atomic, aggregated, multi_hop, cot, vqa
-      data_format: Alpaca # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml DELETED
@@ -1,33 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step] # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step] # partition_step depends on build_kg
-    params:
-      method: leiden # leiden is a partitioner detection algorithm
-      method_params:
-        max_size: 20 # Maximum size of communities
-        use_lcc: false # whether to use the largest connected component
-        random_seed: 42 # random seed for partitioning
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step] # generate_step depends on partition_step
-    params:
-      method: cot # atomic, aggregated, multi_hop, cot, vqa
-      data_format: Sharegpt # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml DELETED
@@ -1,34 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step] # build_kg_step depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step] # partition_step depends on build_kg_step
-    params:
-      method: ece # ece is a custom partition method based on comprehension loss
-      method_params:
-        max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
-        min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
-        max_tokens_per_community: 10240 # max tokens per community
-        unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step] # generate_step depends on partition_step
-    params:
-      method: multi_hop # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML # Alpaca, Sharegpt, ChatML
graphgen/configs/schema_guided_extraction_config.yaml DELETED
@@ -1,20 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 20480
-      chunk_overlap: 2000
-      separators: []
-
-  - name: extract_step
-    op_key: extract
-    deps: [chunk_step] # extract_step depends on chunk_step
-    params:
-      method: schema_guided # extraction method, support: schema_guided
-      schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method
 
graphgen/configs/search_dna_config.yaml DELETED
@@ -1,17 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_dna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step] # search_step depends on read_step
-    params:
-      data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
-      ncbi_params:
-        email: test@example.com # NCBI requires an email address
-        tool: GraphGen # tool name for NCBI API
-        use_local_blast: true # whether to use local blast for DNA search
-        local_blast_db: refseq_release/refseq_release # path to local BLAST database (without .nhr extension)
-
 
graphgen/configs/search_protein_config.yaml DELETED
@@ -1,15 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_protein_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step] # search_step depends on read_step
-    params:
-      data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
-      uniprot_params:
-        use_local_blast: true # whether to use local blast for uniprot search
-        local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot
-        # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database)
 
graphgen/configs/search_rna_config.yaml DELETED
@@ -1,14 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_rna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step] # search_step depends on read_step
-    params:
-      data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
-      rnacentral_params:
-        use_local_blast: true # whether to use local blast for RNA search
-        local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension)
 
graphgen/configs/vqa_config.yaml DELETED
@@ -1,32 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step] # chunk_step depends on read_step
-    params:
-      chunk_size: 1024 # chunk size for text splitting
-      chunk_overlap: 100 # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step] # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step] # partition_step depends on build_kg_step
-    params:
-      method: anchor_bfs # partition method
-      method_params:
-        anchor_type: image # node type to select anchor nodes
-        max_units_per_community: 10 # atomic partition, one node or edge per community
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step] # generate_step depends on partition_step
-    params:
-      method: vqa # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML # Alpaca, Sharegpt, ChatML
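
The step-based YAML pipelines above were all removed in this commit; the rewritten engine below consumes a node-graph configuration instead. The exact schema lives in `graphgen.bases.Config`/`Node`, which this diff does not show, so the following Python sketch is an assumption pieced together from the fields the new `Engine` actually reads (`global_params`, `nodes`, `id`, `op_name`, `type`, `dependencies`, `params`, `execution_params`):

```python
# Hypothetical node-graph config; only the key names listed above are
# grounded in the Engine code below, everything else is illustrative.
config = {
    "global_params": {"working_dir": "cache"},
    "nodes": [
        {
            "id": "read_step",
            "op_name": "read",
            "type": "source",        # source nodes produce the initial Dataset
            "dependencies": [],
            "params": {"input_path": "resources/input_examples/json_demo.json"},
        },
        {
            "id": "chunk_step",
            "op_name": "chunk",
            "type": "map",           # map nodes transform one row at a time
            "dependencies": ["read_step"],
            "params": {"chunk_size": 1024, "chunk_overlap": 100},
            "execution_params": {"replicas": 1},
        },
    ],
}
```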
 
graphgen/engine.py CHANGED
@@ -1,125 +1,210 @@
-"""
-orchestration engine for GraphGen
-"""
-
-import threading
-import traceback
-from typing import Any, Callable, List
-
-
-class Context(dict):
-    _lock = threading.Lock()
-
-    def set(self, k, v):
-        with self._lock:
-            self[k] = v
-
-    def get(self, k, default=None):
-        with self._lock:
-            return super().get(k, default)
-
-
-class OpNode:
-    def __init__(
-        self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any]
-    ):
-        self.name, self.deps, self.func = name, deps, func
-
-
-class Engine:
-    def __init__(self, max_workers: int = 4):
-        self.max_workers = max_workers
-
-    def run(self, ops: List[OpNode], ctx: Context):
-        self._validate(ops)
-        name2op = {operation.name: operation for operation in ops}
-
-        # topological sort
-        graph = {n: set(name2op[n].deps) for n in name2op}
-        topo = []
-        q = [n for n, d in graph.items() if not d]
-        while q:
-            cur = q.pop(0)
-            topo.append(cur)
-            for child in [c for c, d in graph.items() if cur in d]:
-                graph[child].remove(cur)
-                if not graph[child]:
-                    q.append(child)
-
-        if len(topo) != len(ops):
-            raise ValueError(
-                "Cyclic dependencies detected among operations."
-                "Please check your configuration."
-            )
-
-        # semaphore for max_workers
-        sem = threading.Semaphore(self.max_workers)
-        done = {n: threading.Event() for n in name2op}
-        exc = {}
-
-        def _exec(n: str):
-            with sem:
-                for d in name2op[n].deps:
-                    done[d].wait()
-                if any(d in exc for d in name2op[n].deps):
-                    exc[n] = Exception("Skipped due to failed dependencies")
-                    done[n].set()
-                    return
-                try:
-                    name2op[n].func(name2op[n], ctx)
-                except Exception:
-                    exc[n] = traceback.format_exc()
-                done[n].set()
-
-        ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo]
-        for t in ts:
-            t.start()
-        for t in ts:
-            t.join()
-        if exc:
-            raise RuntimeError(
-                "Some operations failed:\n"
-                + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items())
-            )
-
-    @staticmethod
-    def _validate(ops: List[OpNode]):
-        name_set = set()
-        for op in ops:
-            if op.name in name_set:
-                raise ValueError(f"Duplicate operation name: {op.name}")
-            name_set.add(op.name)
-        for op in ops:
-            for dep in op.deps:
-                if dep not in name_set:
-                    raise ValueError(
-                        f"Operation {op.name} has unknown dependency: {dep}"
-                    )
-
-
-def collect_ops(config: dict, graph_gen) -> List[OpNode]:
-    """
-    build operation nodes from yaml config
-    :param config
-    :param graph_gen
-    """
-    ops: List[OpNode] = []
-    for stage in config["pipeline"]:
-        name = stage["name"]
-        method_name = stage.get("op_key")
-        method = getattr(graph_gen, method_name)
-        deps = stage.get("deps", [])
-
-        if "params" in stage:
-
-            def func(self, ctx, _method=method, _params=stage.get("params", {})):
-                return _method(_params)
-
-        else:
-
-            def func(self, ctx, _method=method):
-                return _method()
-
-        op_node = OpNode(name=name, deps=deps, func=func)
-        ops.append(op_node)
-    return ops
+import inspect
+import logging
+from collections import defaultdict, deque
+from functools import wraps
+from typing import Any, Callable, Dict, List, Set
+
+import ray
+import ray.data
+
+from graphgen.bases import Config, Node
+from graphgen.utils import logger
+
+
+class Engine:
+    def __init__(
+        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
+    ):
+        self.config = Config(**config)
+        self.global_params = self.config.global_params
+        self.functions = functions
+        self.datasets: Dict[str, ray.data.Dataset] = {}
+
+        if not ray.is_initialized():
+            context = ray.init(
+                ignore_reinit_error=True,
+                logging_level=logging.ERROR,
+                log_to_driver=True,
+                **ray_init_kwargs,
+            )
+            logger.info("Ray Dashboard URL: %s", context.dashboard_url)
+
+    @staticmethod
+    def _topo_sort(nodes: List[Node]) -> List[Node]:
+        id_to_node: Dict[str, Node] = {}
+        for n in nodes:
+            id_to_node[n.id] = n
+
+        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
+        adj: Dict[str, List[str]] = defaultdict(list)
+
+        for n in nodes:
+            nid = n.id
+            deps: List[str] = n.dependencies
+            uniq_deps: Set[str] = set(deps)
+            for d in uniq_deps:
+                if d not in id_to_node:
+                    raise ValueError(
+                        f"The dependency node id {d} of node {nid} is not defined in the configuration."
+                    )
+                indeg[nid] += 1
+                adj[d].append(nid)
+
+        zero_deg: deque = deque(
+            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
+        )
+        sorted_nodes: List[Node] = []
+
+        while zero_deg:
+            cur = zero_deg.popleft()
+            sorted_nodes.append(cur)
+            cur_id = cur.id
+            for nb_id in adj.get(cur_id, []):
+                indeg[nb_id] -= 1
+                if indeg[nb_id] == 0:
+                    zero_deg.append(id_to_node[nb_id])
+
+        if len(sorted_nodes) != len(nodes):
+            remaining = [nid for nid, deg in indeg.items() if deg > 0]
+            raise ValueError(
+                f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}"
+            )
+
+        return sorted_nodes
+
+    def _get_input_dataset(
+        self, node: Node, initial_ds: ray.data.Dataset
+    ) -> ray.data.Dataset:
+        deps = node.dependencies
+
+        if not deps:
+            return initial_ds
+
+        if len(deps) == 1:
+            return self.datasets[deps[0]]
+
+        main_ds = self.datasets[deps[0]]
+        other_dss = [self.datasets[d] for d in deps[1:]]
+        return main_ds.union(*other_dss)
+
+    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
+        def _filter_kwargs(
+            func_or_class: Callable,
+            global_params: Dict[str, Any],
+            func_params: Dict[str, Any],
+        ) -> Dict[str, Any]:
+            """
+            1. global_params: only when specified in function signature, will be passed
+            2. func_params: pass specified params first, then **kwargs if exists
+            """
+            try:
+                sig = inspect.signature(func_or_class)
+            except ValueError:
+                return {}
+
+            params = sig.parameters
+            final_kwargs = {}
+
+            has_var_keywords = any(
+                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
+            )
+            valid_keys = set(params.keys())
+            for k, v in global_params.items():
+                if k in valid_keys:
+                    final_kwargs[k] = v
+
+            for k, v in func_params.items():
+                if k in valid_keys or has_var_keywords:
+                    final_kwargs[k] = v
+            return final_kwargs
+
+        if node.op_name not in self.functions:
+            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")
+
+        op_handler = self.functions[node.op_name]
+        node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})
+
+        if node.type == "source":
+            self.datasets[node.id] = op_handler(**node_params)
+            return
+
+        input_ds = self._get_input_dataset(node, initial_ds)
+
+        if inspect.isclass(op_handler):
+            execution_params = node.execution_params or {}
+            replicas = execution_params.get("replicas", 1)
+            batch_size = (
+                int(execution_params.get("batch_size"))
+                if "batch_size" in execution_params
+                else "default"
+            )
+            compute_resources = execution_params.get("compute_resources", {})
+
+            if node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                    batch_size=None,  # aggregate processes the whole dataset at once
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
+            else:
+                # others like map, filter, flatmap, map_batch let actors process data inside batches
+                self.datasets[node.id] = input_ds.map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+                    batch_size=batch_size,
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
+
+        else:
+
+            @wraps(op_handler)
+            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
+                return op_handler(row_or_batch, **node_params)
+
+            if node.type == "map":
+                self.datasets[node.id] = input_ds.map(func_wrapper)
+            elif node.type == "filter":
+                self.datasets[node.id] = input_ds.filter(func_wrapper)
+            elif node.type == "flatmap":
+                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
+            elif node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    func_wrapper, batch_format="default"
+                )
+            elif node.type == "map_batch":
+                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
+            else:
+                raise ValueError(
+                    f"Unsupported node type {node.type} for node {node.id}"
+                )
+
+    @staticmethod
+    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
+        all_ids = {n.id for n in nodes}
+        deps_set = set()
+        for n in nodes:
+            deps_set.update(n.dependencies)
+        return all_ids - deps_set
+
+    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
+        sorted_nodes = self._topo_sort(self.config.nodes)
+
+        for node in sorted_nodes:
+            self._execute_node(node, initial_ds)
+
+        leaf_nodes = self._find_leaf_nodes(sorted_nodes)
+
+        @ray.remote
+        def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
+            return ds.take_all()
+
+        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
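
A hedged sketch of driving the rewritten engine end to end, reusing the hypothetical `config` sketched before this file. The operator bodies here are stand-ins; only the calling conventions (source operators return a `ray.data.Dataset`, map operators take a row plus their node `params`) come from the `Engine` code above:

```python
from pathlib import Path

import ray.data

from graphgen.engine import Engine

def read(input_path: str) -> ray.data.Dataset:          # "source" node
    return ray.data.from_items(
        [{"type": "text", "content": Path(input_path).read_text(encoding="utf-8")}]
    )

def chunk(row, chunk_size=1024, chunk_overlap=100):      # "map" node
    row["content"] = row["content"][:chunk_size]         # stand-in for real splitting
    return row

engine = Engine(config=config, functions={"read": read, "chunk": chunk})
results = engine.execute(initial_ds=ray.data.from_items([]))  # leaf node id -> Dataset
for node_id, ds in results.items():
    print(node_id, ds.take(2))
```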
 
 
graphgen/graphgen.py DELETED
@@ -1,295 +0,0 @@
-import os
-import time
-from typing import Dict
-
-import gradio as gr
-
-from graphgen.bases import BaseLLMWrapper
-from graphgen.bases.datatypes import Chunk
-from graphgen.models import (
-    JsonKVStorage,
-    JsonListStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-)
-from graphgen.operators import (
-    build_kg,
-    chunk_documents,
-    extract_info,
-    generate_qas,
-    init_llm,
-    judge_statement,
-    partition_kg,
-    quiz,
-    read_files,
-    search_all,
-)
-from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
-
-sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-
-
-class GraphGen:
-    def __init__(
-        self,
-        unique_id: int = int(time.time()),
-        working_dir: str = os.path.join(sys_path, "cache"),
-        tokenizer_instance: Tokenizer = None,
-        synthesizer_llm_client: OpenAIClient = None,
-        trainee_llm_client: OpenAIClient = None,
-        progress_bar: gr.Progress = None,
-    ):
-        self.unique_id: int = unique_id
-        self.working_dir: str = working_dir
-
-        # llm
-        self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
-            model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
-        )
-
-        self.synthesizer_llm_client: BaseLLMWrapper = (
-            synthesizer_llm_client or init_llm("synthesizer")
-        )
-        self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
-
-        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="full_docs"
-        )
-        self.chunks_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="chunks"
-        )
-        self.graph_storage: NetworkXStorage = NetworkXStorage(
-            self.working_dir, namespace="graph"
-        )
-        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="rephrase"
-        )
-        self.partition_storage: JsonListStorage = JsonListStorage(
-            self.working_dir, namespace="partition"
-        )
-        self.search_storage: JsonKVStorage = JsonKVStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="search",
-        )
-        self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="qa",
-        )
-        self.extract_storage: JsonKVStorage = JsonKVStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="extraction",
-        )
-
-        # webui
-        self.progress_bar: gr.Progress = progress_bar
-
-    @async_to_sync_method
-    async def read(self, read_config: Dict):
-        """
-        read files from input sources
-        """
-        doc_stream = read_files(**read_config, cache_dir=self.working_dir)
-
-        batch = {}
-        for doc in doc_stream:
-            doc_id = compute_mm_hash(doc, prefix="doc-")
-            batch[doc_id] = doc
-
-        # TODO: configurable whether to use coreference resolution
-
-        _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
-        new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
-        if len(new_docs) == 0:
-            logger.warning("All documents are already in the storage")
-            return
-        self.full_docs_storage.upsert(new_docs)
-        self.full_docs_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def chunk(self, chunk_config: Dict):
-        """
-        chunk documents into smaller pieces from full_docs_storage if not already present
-        """
-
-        new_docs = self.full_docs_storage.get_all()
-        if len(new_docs) == 0:
-            logger.warning("All documents are already in the storage")
-            return
-
-        inserting_chunks = await chunk_documents(
-            new_docs,
-            self.tokenizer_instance,
-            self.progress_bar,
-            **chunk_config,
-        )
-
-        _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
-        inserting_chunks = {
-            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-        }
-
-        if len(inserting_chunks) == 0:
-            logger.warning("All chunks are already in the storage")
-            return
-
-        self.chunks_storage.upsert(inserting_chunks)
-        self.chunks_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def build_kg(self):
-        """
-        build knowledge graph from text chunks
-        """
-        # Step 1: get new chunks
-        inserting_chunks = self.chunks_storage.get_all()
-
-        if len(inserting_chunks) == 0:
-            logger.warning("All chunks are already in the storage")
-            return
-
-        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
-        # Step 2: build knowledge graph from new chunks
-        _add_entities_and_relations = await build_kg(
-            llm_client=self.synthesizer_llm_client,
-            kg_instance=self.graph_storage,
-            chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
-            progress_bar=self.progress_bar,
-        )
-        if not _add_entities_and_relations:
-            logger.warning("No entities or relations extracted from text chunks")
-            return
-
-        # Step 3: upsert new entities and relations to the graph storage
-        self.graph_storage.index_done_callback()
-
-        return _add_entities_and_relations
-
-    @async_to_sync_method
-    async def search(self, search_config: Dict):
-        logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
-
-        seeds = self.full_docs_storage.get_all()
-        if len(seeds) == 0:
-            logger.warning("All documents are already been searched")
-            return
-        search_results = await search_all(
-            seed_data=seeds,
-            search_config=search_config,
-        )
-
-        _add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
-        search_results = {
-            k: v for k, v in search_results.items() if k in _add_search_keys
-        }
-        if len(search_results) == 0:
-            logger.warning("All search results are already in the storage")
-            return
-        self.search_storage.upsert(search_results)
-        self.search_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
-        logger.warning(
-            "Quiz and Judge operation needs trainee LLM client."
-            " Make sure to provide one."
-        )
-        max_samples = quiz_and_judge_config["quiz_samples"]
-        await quiz(
-            self.synthesizer_llm_client,
-            self.graph_storage,
-            self.rephrase_storage,
-            max_samples,
-            progress_bar=self.progress_bar,
-        )
-
-        # TODO: assert trainee_llm_client is valid before judge
-        if not self.trainee_llm_client:
-            # TODO: shutdown existing synthesizer_llm_client properly
-            logger.info("No trainee LLM client provided, initializing a new one.")
-            self.synthesizer_llm_client.shutdown()
-            self.trainee_llm_client = init_llm("trainee")
-
-        re_judge = quiz_and_judge_config["re_judge"]
-        _update_relations = await judge_statement(
-            self.trainee_llm_client,
-            self.graph_storage,
-            self.rephrase_storage,
-            re_judge,
-            progress_bar=self.progress_bar,
-        )
-
-        self.rephrase_storage.index_done_callback()
-        _update_relations.index_done_callback()
-
-        logger.info("Shutting down trainee LLM client.")
-        self.trainee_llm_client.shutdown()
-        self.trainee_llm_client = None
-        logger.info("Restarting synthesizer LLM client.")
-        self.synthesizer_llm_client.restart()
-
-    @async_to_sync_method
-    async def partition(self, partition_config: Dict):
-        batches = await partition_kg(
-            self.graph_storage,
-            self.chunks_storage,
-            self.tokenizer_instance,
-            partition_config,
-        )
-        self.partition_storage.upsert(batches)
-        return batches
-
-    @async_to_sync_method
-    async def extract(self, extract_config: Dict):
-        logger.info("Extracting information from given chunks...")
-
-        results = await extract_info(
-            self.synthesizer_llm_client,
-            self.chunks_storage,
-            extract_config,
-            progress_bar=self.progress_bar,
-        )
-        if not results:
-            logger.warning("No information extracted")
-            return
-
-        self.extract_storage.upsert(results)
-        self.extract_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate(self, generate_config: Dict):
-
-        batches = self.partition_storage.data
-        if not batches:
-            logger.warning("No partitions found for QA generation")
-            return
-
-        # Step 2: generate QA pairs
-        results = await generate_qas(
-            self.synthesizer_llm_client,
-            batches,
-            generate_config,
-            progress_bar=self.progress_bar,
-        )
-
-        if not results:
-            logger.warning("No QA pairs generated")
-            return
-
-        # Step 3: store the generated QA pairs
-        self.qa_storage.upsert(results)
-        self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def clear(self):
-        self.full_docs_storage.drop()
-        self.chunks_storage.drop()
-        self.search_storage.drop()
-        self.graph_storage.clear()
-        self.rephrase_storage.drop()
-        self.qa_storage.drop()
-
-        logger.info("All caches are cleared")
-
-    # TODO: add data filtering step here in the future
-    # graph_gen.filter(filter_config=config["filter"])
 
graphgen/models/__init__.py CHANGED
@@ -18,7 +18,6 @@ from .partitioner import (
 )
 from .reader import (
     CSVReader,
-    JSONLReader,
     JSONReader,
     ParquetReader,
     PDFReader,
@@ -33,5 +32,11 @@ from .searcher.kg.wiki_search import WikiSearch
 from .searcher.web.bing_search import BingSearch
 from .searcher.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage, RocksDBCache
+from .storage import (
+    JsonKVStorage,
+    KuzuStorage,
+    NetworkXStorage,
+    RocksDBCache,
+    RocksDBKVStorage,
+)
 from .tokenizer import Tokenizer
graphgen/models/extractor/schema_guided_extractor.py CHANGED
@@ -60,8 +60,8 @@ class SchemaGuidedExtractor(BaseExtractor):
         return prompt
 
     async def extract(self, chunk: dict) -> dict:
-        _chunk_id = list(chunk.keys())[0]
-        text = chunk[_chunk_id].get("content", "")
+        _chunk_id = chunk.get("_chunk_id", "")
+        text = chunk.get("content", "")
 
         prompt = self.build_prompt(text)
         response = await self.llm_client.generate_answer(prompt)
@@ -88,9 +88,7 @@ class SchemaGuidedExtractor(BaseExtractor):
         return {}
 
     @staticmethod
-    async def merge_extractions(
-        extraction_list: List[Dict[str, dict]]
-    ) -> Dict[str, dict]:
+    def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]:
         """
        Merge multiple extraction results based on their hashes.
        :param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
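
`extract` now takes a flat chunk row instead of a `{chunk_id: {...}}` mapping, and `merge_extractions` is a plain staticmethod. A minimal sketch of both call shapes; the chunk values are made up for illustration, and the precise merge policy lives in the method body rather than this excerpt:

```python
from graphgen.models.extractor.schema_guided_extractor import SchemaGuidedExtractor

# Flat row shape extract() now expects (hypothetical values):
chunk = {"_chunk_id": "chunk-abc123", "content": "Party A shall deliver ..."}
# result = await extractor.extract(chunk)  # extractor built elsewhere; extract() is async

# merge_extractions no longer needs an event loop:
merged = SchemaGuidedExtractor.merge_extractions(
    [{"hash-1": {"party": "A"}}, {"hash-1": {"party": "A"}, "hash-2": {"party": "B"}}]
)
print(merged)  # one record per hash
```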
graphgen/models/generator/vqa_generator.py CHANGED
@@ -77,8 +77,8 @@ class VQAGenerator(BaseGenerator):
         nodes, _ = batch
         for node in nodes:
             node_data = node[1]
-            if "images" in node_data and node_data["images"]:
-                img_path = node_data["images"]["img_path"]
+            if "image_data" in node_data and node_data["image_data"]:
+                img_path = node_data["image_data"]["img_path"]
                 for qa in qa_pairs.values():
                     qa["img_path"] = img_path
         result.update(qa_pairs)
graphgen/models/llm/local/sglang_wrapper.py CHANGED
@@ -138,15 +138,3 @@ class SGLangWrapper(BaseLLMWrapper):
         raise NotImplementedError(
             "SGLangWrapper does not support per-token logprobs yet."
         )
-
-    def shutdown(self) -> None:
-        """Gracefully shutdown the SGLang engine."""
-        if hasattr(self, "engine"):
-            self.engine.shutdown()
-
-    def restart(self) -> None:
-        """Restart the SGLang engine."""
-        self.shutdown()
-        self.engine = self.engine.__class__(
-            model_path=self.model_path, tp_size=self.tp_size
-        )
 
graphgen/models/llm/local/vllm_wrapper.py CHANGED
@@ -1,3 +1,5 @@
+import math
+import uuid
 from typing import Any, List, Optional
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
@@ -6,7 +8,7 @@ from graphgen.bases.datatypes import Token
 
 class VLLMWrapper(BaseLLMWrapper):
     """
-    Async inference backend based on vLLM (https://github.com/vllm-project/vllm)
+    Async inference backend based on vLLM.
     """
 
     def __init__(
@@ -20,12 +22,11 @@ class VLLMWrapper(BaseLLMWrapper):
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
-
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         except ImportError as exc:
             raise ImportError(
-                "VLLMWrapper requires vllm. Install it with: uv pip install vllm --torch-backend=auto"
+                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
             ) from exc
 
         self.SamplingParams = SamplingParams
@@ -35,9 +36,9 @@ class VLLMWrapper(BaseLLMWrapper):
             tensor_parallel_size=tensor_parallel_size,
             gpu_memory_utilization=gpu_memory_utilization,
             trust_remote_code=kwargs.get("trust_remote_code", True),
+            disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
@@ -60,6 +61,7 @@ class VLLMWrapper(BaseLLMWrapper):
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full_prompt = self._build_inputs(text, history)
+        request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=self.temperature if self.temperature > 0 else 1.0,
@@ -67,71 +69,57 @@ class VLLMWrapper(BaseLLMWrapper):
             max_tokens=extra.get("max_new_tokens", 512),
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_req"
-        ):
-            results = req_output.outputs
-        return results[-1].text
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs:
+            return ""
+
+        return final_output.outputs[0].text
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
 
+        request_id = f"graphgen_topk_{uuid.uuid4()}"
+
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
         )
 
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_topk"
-        ):
-            results = req_output.outputs
-        top_logprobs = results[-1].logprobs[0]
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if (
+            not final_output
+            or not final_output.outputs
+            or not final_output.outputs[0].logprobs
+        ):
+            return []
+
+        top_logprobs = final_output.outputs[0].logprobs[0]
 
         tokens = []
         for _, logprob_obj in top_logprobs.items():
             tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
+            prob = float(math.exp(logprob_obj.logprob))
             tokens.append(Token(tok_str, prob))
+
         tokens.sort(key=lambda x: -x.prob)
         return tokens
 
     async def generate_inputs_prob(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
-        full_prompt = self._build_inputs(text, history)
-
-        # vLLM has no ready-made "mask one token, then score it" API, so we
-        # take the most direct route: feed the whole prompt in once with
-        # prompt_logprobs enabled, let vLLM return the logprob of every
-        # position in the *input*, then pick out each token's probability.
-        sp = self.SamplingParams(
-            temperature=0,
-            max_tokens=0,  # generate no new tokens
-            prompt_logprobs=1,  # top-1 is enough
-        )
-
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_prob"
-        ):
-            results = req_output.outputs
-
-        # prompt_logprobs is a list whose length equals the number of prompt
-        # tokens; each element is a dict{token_id: logprob_obj} or None
-        # (the first position is None)
-        prompt_logprobs = results[-1].prompt_logprobs
-
-        tokens = []
-        for _, logprob_dict in enumerate(prompt_logprobs):
-            if logprob_dict is None:
-                continue
-            # each dict holds exactly one kv pair, because of top-1
-            _, logprob_obj = next(iter(logprob_dict.items()))
-            tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
-            tokens.append(Token(tok_str, prob))
-        return tokens
+        raise NotImplementedError(
+            "VLLMWrapper does not support per-token logprobs yet."
        )
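
A hedged usage sketch for the rewritten wrapper. The constructor's model argument and the `generate_answer` method name are assumptions (the base wrapper's signature is not in this excerpt; `generate_answer` is what other GraphGen code calls on LLM clients). The per-request UUIDs above are what make overlapping calls like this safe, where the old constant `request_id` collided:

```python
import asyncio

from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

async def main():
    # model keyword is an assumption; pass whatever the wrapper's __init__ expects
    llm = VLLMWrapper(model="Qwen/Qwen2.5-7B-Instruct", tensor_parallel_size=1)
    answers = await asyncio.gather(
        llm.generate_answer("Define a knowledge graph."),
        llm.generate_answer("What does topological sorting do?"),
    )
    print(answers)

asyncio.run(main())
```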
 
graphgen/models/partitioner/anchor_bfs_partitioner.py CHANGED
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List, Literal, Set, Tuple
+from typing import Any, Iterable, List, Literal, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -30,24 +30,23 @@ class AnchorBFSPartitioner(BFSPartitioner):
         self.anchor_type = anchor_type
         self.anchor_ids = anchor_ids
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()  # List[tuple[id, meta]]
         edges = g.get_all_edges()  # List[tuple[u, v, meta]]
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
-        anchors: Set[str] = await self._pick_anchor_ids(nodes)
+        anchors: Set[str] = self._pick_anchor_ids(nodes)
         if not anchors:
-            return []  # if no anchors, return empty list
+            return  # if no anchors, return nothing
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         seeds = list(anchors)
         random.shuffle(seeds)
@@ -55,17 +54,13 @@ class AnchorBFSPartitioner(BFSPartitioner):
         for seed_node in seeds:
             if seed_node in used_n:
                 continue
-            comm_n, comm_e = await self._grow_community(
+            comm_n, comm_e = self._grow_community(
                 seed_node, adj, max_units_per_community, used_n, used_e
             )
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
+                yield Community(id=seed_node, nodes=comm_n, edges=comm_e)
 
-        return communities
-
-    async def _pick_anchor_ids(
+    def _pick_anchor_ids(
         self,
         nodes: List[tuple[str, dict]],
     ) -> Set[str]:
@@ -80,7 +75,7 @@ class AnchorBFSPartitioner(BFSPartitioner):
         return anchor_ids
 
     @staticmethod
-    async def _grow_community(
+    def _grow_community(
         seed: str,
         adj: dict[str, List[str]],
         max_units: int,
graphgen/models/partitioner/bfs_partitioner.py CHANGED
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List
+from typing import Any, Iterable, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -17,12 +17,12 @@ class BFSPartitioner(BasePartitioner):
     (A unit is a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
@@ -30,7 +30,6 @@ class BFSPartitioner(BasePartitioner):
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -74,8 +73,4 @@ class BFSPartitioner(BasePartitioner):
                 queue.append((NODE_UNIT, n))
 
         if comm_n or comm_e:
-            communities.append(
-                Community(id=len(communities), nodes=comm_n, edges=comm_e)
-            )
-
-        return communities
+            yield Community(id=seed, nodes=comm_n, edges=comm_e)
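
`partition` is now a plain generator rather than an async method returning a list, so communities stream out one at a time. A sketch of consuming it; the storage constructor shape is borrowed from the deleted graphgen.py above and may differ in the current codebase:

```python
from graphgen.models import NetworkXStorage
from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner

g = NetworkXStorage("cache", namespace="graph")  # constructor shape assumed
partitioner = BFSPartitioner()

# no event loop needed any more, and nothing is materialized up front
for community in partitioner.partition(g, max_units_per_community=8):
    print(community.id, len(community.nodes), len(community.edges))
```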
 
 
 
 
graphgen/models/partitioner/dfs_partitioner.py CHANGED
@@ -1,5 +1,6 @@
 import random
-from typing import Any, List
+from collections.abc import Iterable
+from typing import Any
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -16,12 +17,12 @@ class DFSPartitioner(BasePartitioner):
     (In GraphGen, a unit is defined as a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
@@ -29,7 +30,6 @@ class DFSPartitioner(BasePartitioner):
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -71,8 +71,4 @@ class DFSPartitioner(BasePartitioner):
                 stack.append((NODE_UNIT, n))
 
         if comm_n or comm_e:
-            communities.append(
-                Community(id=len(communities), nodes=comm_n, edges=comm_e)
-            )
-
-        return communities
+            yield Community(id=seed, nodes=comm_n, edges=comm_e)
 
 
 
 
graphgen/models/partitioner/ece_partitioner.py CHANGED
@@ -1,8 +1,8 @@
-import asyncio
 import random
-from typing import Any, Dict, List, Optional, Set, Tuple
+from collections import deque
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
 
-from tqdm.asyncio import tqdm as tqdm_async
+from tqdm import tqdm
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -51,7 +51,7 @@ class ECEPartitioner(BFSPartitioner):
             raise ValueError(f"Invalid edge sampling: {edge_sampling}")
         return units
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 10,
@@ -59,7 +59,7 @@ class ECEPartitioner(BFSPartitioner):
         max_tokens_per_community: int = 10240,
         unit_sampling: str = "random",
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes: List[Tuple[str, dict]] = g.get_all_nodes()
         edges: List[Tuple[str, str, dict]] = g.get_all_edges()
 
@@ -73,21 +73,18 @@ class ECEPartitioner(BFSPartitioner):
 
         used_n: Set[str] = set()
         used_e: Set[frozenset[str]] = set()
-        communities: List = []
 
         all_units = self._sort_units(all_units, unit_sampling)
 
-        async def _grow_community(
-            seed_unit: Tuple[str, Any, dict]
-        ) -> Optional[Community]:
+        def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]:
            nonlocal used_n, used_e
 
            community_nodes: Dict[str, dict] = {}
            community_edges: Dict[frozenset[str], dict] = {}
-            queue: asyncio.Queue = asyncio.Queue()
+            queue = deque()
            token_sum = 0
 
-            async def _add_unit(u):
+            def _add_unit(u):
                nonlocal token_sum
                t, i, d = u
                if t == NODE_UNIT:  # node
@@ -103,11 +100,11 @@ class ECEPartitioner(BFSPartitioner):
                    token_sum += d.get("length", 0)
                return True
 
-            await _add_unit(seed_unit)
-            await queue.put(seed_unit)
+            _add_unit(seed_unit)
+            queue.append(seed_unit)
 
            # BFS
-            while not queue.empty():
+            while queue:
                if (
                    len(community_nodes) + len(community_edges)
                    >= max_units_per_community
@@ -115,7 +112,7 @@ class ECEPartitioner(BFSPartitioner):
                ):
                    break
 
-                cur_type, cur_id, _ = await queue.get()
+                cur_type, cur_id, _ = queue.popleft()
 
                neighbors: List[Tuple[str, Any, dict]] = []
                if cur_type == NODE_UNIT:
@@ -136,26 +133,24 @@ class ECEPartitioner(BFSPartitioner):
                        or token_sum >= max_tokens_per_community
                    ):
                        break
-                    if await _add_unit(nb):
-                        await queue.put(nb)
+                    if _add_unit(nb):
+                        queue.append(nb)
 
            if len(community_nodes) + len(community_edges) < min_units_per_community:
                return None
 
            return Community(
-                id=len(communities),
+                id=seed_unit[1],
                nodes=list(community_nodes.keys()),
                edges=[(u, v) for (u, v), _ in community_edges.items()],
            )
 
-        async for unit in tqdm_async(all_units, desc="ECE partition"):
+        for unit in tqdm(all_units, desc="ECE partition"):
            utype, uid, _ = unit
            if (utype == NODE_UNIT and uid in used_n) or (
                utype == EDGE_UNIT and uid in used_e
            ):
                continue
-            comm = await _grow_community(unit)
-            if comm is not None:
-                communities.append(comm)
-
-        return communities
+            comm = _grow_community(unit)
+            if comm:
+                yield comm
 
graphgen/models/partitioner/leiden_partitioner.py CHANGED
@@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner):
     Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_size: int = 20,
@@ -37,12 +37,10 @@ class LeidenPartitioner(BasePartitioner):
         nodes = g.get_all_nodes()  # List[Tuple[str, dict]]
         edges = g.get_all_edges()  # List[Tuple[str, str, dict]]
 
-        node2cid: Dict[str, int] = await self._run_leiden(
-            nodes, edges, use_lcc, random_seed
-        )
+        node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed)
 
         if max_size is not None and max_size > 0:
-            node2cid = await self._split_communities(node2cid, max_size)
+            node2cid = self._split_communities(node2cid, max_size)
 
         cid2nodes: Dict[int, List[str]] = defaultdict(list)
         for n, cid in node2cid.items():
@@ -58,7 +56,7 @@ class LeidenPartitioner(BasePartitioner):
         return communities
 
     @staticmethod
-    async def _run_leiden(
+    def _run_leiden(
         nodes: List[Tuple[str, dict]],
         edges: List[Tuple[str, str, dict]],
         use_lcc: bool = False,
@@ -92,9 +90,7 @@ class LeidenPartitioner(BasePartitioner):
         return node2cid
 
     @staticmethod
-    async def _split_communities(
-        node2cid: Dict[str, int], max_size: int
-    ) -> Dict[str, int]:
+    def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]:
         """
         Split communities larger than max_size into smaller sub-communities.
         """
graphgen/models/reader/__init__.py CHANGED
@@ -1,6 +1,5 @@
 from .csv_reader import CSVReader
 from .json_reader import JSONReader
-from .jsonl_reader import JSONLReader
 from .parquet_reader import ParquetReader
 from .pdf_reader import PDFReader
 from .pickle_reader import PickleReader
graphgen/models/reader/csv_reader.py CHANGED
@@ -1,6 +1,7 @@
-from typing import Any, Dict, List
+from typing import List, Union
 
-import pandas as pd
+import ray
+from ray.data import Dataset
 
 from graphgen.bases.base_reader import BaseReader
 
@@ -13,13 +14,15 @@ class CSVReader(BaseReader):
     - if type is "text", "content" column must be present.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-
-        df = pd.read_csv(file_path)
-        for _, row in df.iterrows():
-            assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}"
-            if row["type"] == "text" and self.text_column not in row:
-                raise ValueError(
-                    f"Missing '{self.text_column}' in document: {row.to_dict()}"
-                )
-        return self.filter(df.to_dict(orient="records"))
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+        """
+        Read CSV files and return Ray Dataset.
+
+        :param input_path: Path to CSV file or list of CSV files.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+
+        ds = ray.data.read_csv(input_path)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
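
All three tabular readers now funnel through the same two base-reader hooks. Their real implementations are in base_reader.py, which changed in this commit but is not shown here, so the following contract sketch is purely an assumption about what they do; only the names `_validate_batch` and `_should_keep_item` are grounded in this diff:

```python
import pandas as pd

# Hypothetical hook bodies, illustrating the contract the readers rely on.
def _validate_batch(batch: pd.DataFrame) -> pd.DataFrame:
    if "type" not in batch.columns:
        raise ValueError("Missing 'type' column")
    text_rows = batch[batch["type"] == "text"]
    if not text_rows.empty and (
        "content" not in batch.columns or text_rows["content"].isna().any()
    ):
        raise ValueError("text rows must carry a 'content' field")
    return batch

def _should_keep_item(row: dict) -> bool:
    # drop rows that carry no usable payload
    return row.get("type") != "text" or bool(row.get("content"))
```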
graphgen/models/reader/json_reader.py CHANGED
@@ -1,26 +1,53 @@
 import json
-from typing import Any, Dict, List
+from typing import List, Union
+
+import ray
+import ray.data
 
 from graphgen.bases.base_reader import BaseReader
 
 
 class JSONReader(BaseReader):
     """
-    Reader for JSON files.
+    Reader for JSON and JSONL files.
     Columns:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        with open(file_path, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        if isinstance(data, list):
-            for doc in data:
-                assert "type" in doc, f"Missing 'type' in document: {doc}"
-                if doc.get("type") == "text" and self.text_column not in doc:
-                    raise ValueError(
-                        f"Missing '{self.text_column}' in document: {doc}"
-                    )
-            return self.filter(data)
-        raise ValueError("JSON file must contain a list of documents.")
+    def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
+        """
+        Read JSON file and return Ray Dataset.
+        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+        if self.modalities and len(self.modalities) >= 2:
+            ds: ray.data.Dataset = ray.data.from_items([])
+            for file in input_path if isinstance(input_path, list) else [input_path]:
+                data = []
+                if file.endswith(".jsonl"):
+                    with open(file, "r", encoding="utf-8") as f:
+                        for line in f:
+                            item = json.loads(line)
+                            data.append(item)
+                else:
+                    with open(file, "r", encoding="utf-8") as f:
+                        data = json.load(f)
+                data = self._unify_schema(data)
+                file_ds: ray.data.Dataset = ray.data.from_items(data)
+                ds = ds.union(file_ds)  # type: ignore
+        else:
+            ds = ray.data.read_json(input_path)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
+
+    @staticmethod
+    def _unify_schema(data):
+        """
+        Unify schema for JSON data.
+        """
+        for item in data:
+            if "content" in item and isinstance(item["content"], dict):
+                item["content"] = json.dumps(item["content"])
+        return data
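
A hedged usage sketch for the Ray-backed readers; constructor options such as `modalities` and `text_column` live on the base reader, whose signature this excerpt does not show, so defaults are assumed:

```python
from graphgen.models.reader.json_reader import JSONReader

reader = JSONReader()  # base-reader options assumed to default sensibly
ds = reader.read("resources/input_examples/json_demo.json")  # ray.data.Dataset
print(ds.take(3))  # validated, filtered rows
```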
graphgen/models/reader/jsonl_reader.py DELETED
@@ -1,30 +0,0 @@
-import json
-from typing import Any, Dict, List
-
-from graphgen.bases.base_reader import BaseReader
-from graphgen.utils import logger
-
-
-class JSONLReader(BaseReader):
-    """
-    Reader for JSONL files.
-    Columns:
-    - type: The type of the document (e.g., "text", "image", etc.)
-    - if type is "text", "content" column must be present.
-    """
-
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        docs = []
-        with open(file_path, "r", encoding="utf-8") as f:
-            for line in f:
-                try:
-                    doc = json.loads(line)
-                    assert "type" in doc, f"Missing 'type' in document: {doc}"
-                    if doc.get("type") == "text" and self.text_column not in doc:
-                        raise ValueError(
-                            f"Missing '{self.text_column}' in document: {doc}"
-                        )
-                    docs.append(doc)
-                except json.JSONDecodeError as e:
-                    logger.error("Error decoding JSON line: %s. Error: %s", line, e)
-        return self.filter(docs)
 
graphgen/models/reader/parquet_reader.py CHANGED
@@ -1,6 +1,7 @@
- from typing import Any, Dict, List
+ from typing import List, Union

- import pandas as pd
+ import ray
+ from ray.data import Dataset

  from graphgen.bases.base_reader import BaseReader

@@ -13,12 +14,17 @@ class ParquetReader(BaseReader):
      - if type is "text", "content" column must be present.
      """

-     def read(self, file_path: str) -> List[Dict[str, Any]]:
-         df = pd.read_parquet(file_path)
-         data: List[Dict[str, Any]] = df.to_dict(orient="records")
-
-         for doc in data:
-             assert "type" in doc, f"Missing 'type' in document: {doc}"
-             if doc.get("type") == "text" and self.text_column not in doc:
-                 raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
-         return self.filter(data)
+     def read(self, input_path: Union[str, List[str]]) -> Dataset:
+         """
+         Read Parquet files using Ray Data.
+
+         :param input_path: Path to Parquet file or list of Parquet files.
+         :return: Ray Dataset containing validated documents.
+         """
+         if not ray.is_initialized():
+             ray.init()
+
+         ds = ray.data.read_parquet(input_path)
+         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+         ds = ds.filter(self._should_keep_item)
+         return ds
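Because validation now happens inside `map_batches`, a local round-trip is an easy sanity check. A sketch under the assumption that `ParquetReader` accepts the same `text_column` keyword as the other readers, with placeholder paths:

```python
# Illustrative round-trip; the path and constructor kwargs are assumptions.
import pandas as pd
import ray

from graphgen.models.reader.parquet_reader import ParquetReader

pd.DataFrame(
    [
        {"type": "text", "content": "hello"},
        {"type": "text", "content": "world"},
    ]
).to_parquet("/tmp/docs.parquet")

ray.init(ignore_reinit_error=True)
ds = ParquetReader(text_column="content").read("/tmp/docs.parquet")
print(ds.take_all())
```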
graphgen/models/reader/pdf_reader.py CHANGED
@@ -5,6 +5,9 @@ import tempfile
  from pathlib import Path
  from typing import Any, Dict, List, Optional, Union

+ import ray
+ from ray.data import Dataset
+
  from graphgen.bases.base_reader import BaseReader
  from graphgen.models.reader.txt_reader import TXTReader
  from graphgen.utils import logger, pick_device
@@ -62,19 +65,31 @@ class PDFReader(BaseReader):
          self.parser = MinerUParser()
          self.txt_reader = TXTReader()

-     def read(self, file_path: str, **override) -> List[Dict[str, Any]]:
-         """
-         file_path
-         **override: override MinerU parameters
-         """
-         pdf_path = Path(file_path).expanduser().resolve()
-         if not pdf_path.is_file():
-             raise FileNotFoundError(pdf_path)
-
-         kwargs = {**self._default_kwargs, **override}
-
-         mineru_result = self._call_mineru(pdf_path, kwargs)
-         return self.filter(mineru_result)
+     def read(
+         self,
+         input_path: Union[str, List[str]],
+         **override,
+     ) -> Dataset:
+
+         # Ensure input_path is a list
+         if isinstance(input_path, str):
+             input_path = [input_path]
+
+         paths_ds = ray.data.from_items(input_path)
+
+         def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
+             try:
+                 pdf_path = row["item"]
+                 kwargs = {**self._default_kwargs, **override}
+                 return self._call_mineru(Path(pdf_path), kwargs)
+             except Exception as e:
+                 logger.error("Failed to process %s: %s", row, e)
+                 return []
+
+         docs_ds = paths_ds.flat_map(process_pdf)
+         docs_ds = docs_ds.filter(self._should_keep_item)
+
+         return docs_ds

      def _call_mineru(
          self, pdf_path: Path, kwargs: Dict[str, Any]
@@ -161,18 +176,18 @@ class MinerUParser:

          base = os.path.dirname(json_file)
          results = []
-         for item in data:
+         for it in data:
              for key in ("img_path", "table_img_path", "equation_img_path"):
-                 rel_path = item.get(key)
+                 rel_path = it.get(key)
                  if rel_path:
-                     item[key] = str(Path(base).joinpath(rel_path).resolve())
-             if item["type"] == "text":
-                 item["content"] = item["text"]
-                 del item["text"]
+                     it[key] = str(Path(base).joinpath(rel_path).resolve())
+             if it["type"] == "text":
+                 it["content"] = it["text"]
+                 del it["text"]
              for key in ("page_idx", "bbox", "text_level"):
-                 if item.get(key) is not None:
-                     del item[key]
-             results.append(item)
+                 if it.get(key) is not None:
+                     del it[key]
+             results.append(it)
          return results

      @staticmethod
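Since each path row goes through `flat_map`, one PDF can expand into many document rows while a failing PDF contributes an empty list instead of aborting the run. A minimal sketch, assuming MinerU and its models are installed and using placeholder file names:

```python
# Illustrative only: paths are placeholders and MinerU must be installed.
from graphgen.models.reader.pdf_reader import PDFReader

reader = PDFReader()
docs_ds = reader.read(["paper1.pdf", "paper2.pdf"])  # PDFs are parsed in parallel

# Text blocks arrive as {"type": "text", "content": ...}; image/table/equation
# rows keep the absolute *_img_path fields resolved by MinerUParser.
for row in docs_ds.take(3):
    print(row.get("type"))
```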
graphgen/models/reader/pickle_reader.py CHANGED
@@ -1,30 +1,78 @@
  import pickle
- from typing import Any, Dict, List
+ from typing import List, Union
+
+ import pandas as pd
+ import ray
+ from ray.data import Dataset

  from graphgen.bases.base_reader import BaseReader
+ from graphgen.utils import logger


  class PickleReader(BaseReader):
      """
-     Read pickle files, requiring the top-level object to be List[Dict[str, Any]].
-
-     Columns:
+     Read pickle files, requiring each file to deserialize to List[Dict[str, Any]].
+     Each pickle file should contain a list of dictionaries with at least:
      - type: The type of the document (e.g., "text", "image", etc.)
      - if type is "text", "content" column must be present.
+
+     Note: uses ray.data.read_binary_files, as ray.data.read_pickle is not available.
      """

-     def read(self, file_path: str) -> List[Dict[str, Any]]:
-         with open(file_path, "rb") as f:
-             data = pickle.load(f)
-
-         if not isinstance(data, list):
-             raise ValueError("Pickle file must contain a list of documents.")
-
-         for doc in data:
-             if not isinstance(doc, dict):
-                 raise ValueError("Every item in the list must be a dict.")
-             assert "type" in doc, f"Missing 'type' in document: {doc}"
-             if doc.get("type") == "text" and self.text_column not in doc:
-                 raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
-
-         return self.filter(data)
+     def read(
+         self,
+         input_path: Union[str, List[str]],
+     ) -> Dataset:
+         """
+         Read Pickle files using Ray Data.
+
+         :param input_path: Path to pickle file or list of pickle files.
+         :return: Ray Dataset containing validated documents.
+         """
+         if not ray.is_initialized():
+             ray.init()
+
+         # Use read_binary_files as a reliable alternative to read_pickle
+         ds = ray.data.read_binary_files(input_path, include_paths=True)
+
+         # Deserialize pickle files and flatten into individual records
+         def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+             all_records = []
+             for _, row in batch.iterrows():
+                 try:
+                     # Load pickle data from bytes
+                     data = pickle.loads(row["bytes"])
+
+                     # Validate structure
+                     if not isinstance(data, list):
+                         logger.error(
+                             "Pickle file %s must contain a list, got %s",
+                             row["path"],
+                             type(data),
+                         )
+                         continue
+
+                     if not all(isinstance(item, dict) for item in data):
+                         logger.error(
+                             "Pickle file %s must contain a list of dictionaries",
+                             row["path"],
+                         )
+                         continue
+
+                     # Flatten: each dict in the list becomes a separate row
+                     all_records.extend(data)
+                 except Exception as e:
+                     logger.error(
+                         "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                     )
+                     continue
+
+             return pd.DataFrame(all_records)
+
+         # Apply deserialization and flattening
+         ds = ds.map_batches(deserialize_batch, batch_format="pandas")
+
+         # Validate the schema
+         ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+
+         # Filter valid items
+         ds = ds.filter(self._should_keep_item)
+         return ds
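A round-trip makes the contract concrete: the pickle must deserialize to a list of dicts, each carrying `type` (and `content` for text), and anything else is logged and skipped rather than raised. A sketch with placeholder paths, assuming `PickleReader` needs no constructor arguments:

```python
# Illustrative round-trip; the path and no-arg constructor are assumptions.
import pickle

from graphgen.models.reader.pickle_reader import PickleReader

docs = [{"type": "text", "content": "alpha"}, {"type": "text", "content": "beta"}]
with open("/tmp/docs.pkl", "wb") as f:
    pickle.dump(docs, f)

ds = PickleReader().read("/tmp/docs.pkl")
print(ds.take_all())  # a malformed pickle file would be skipped with an error log
```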
graphgen/models/reader/rdf_reader.py CHANGED
@@ -1,48 +1,128 @@
- from typing import Any, Dict, List
+ from pathlib import Path
+ from typing import Any, Dict, List, Union

+ import ray
  import rdflib
+ from ray.data import Dataset
  from rdflib import Literal
  from rdflib.util import guess_format

  from graphgen.bases.base_reader import BaseReader
+ from graphgen.utils import logger


  class RDFReader(BaseReader):
      """
      Reader for RDF files that extracts triples and represents them as dictionaries.
+
+     Uses Ray Data for distributed processing of multiple RDF files.
      """

-     def read(self, file_path: str) -> List[Dict[str, Any]]:
+     def __init__(self, *, text_column: str = "content", **kwargs):
+         """
+         Initialize RDFReader.
+
+         :param text_column: The column name for text content (default: "content").
+         """
+         super().__init__(**kwargs)
+         self.text_column = text_column
+
+     def read(
+         self,
+         input_path: Union[str, List[str]],
+     ) -> Dataset:
+         """
+         Read RDF file(s) using Ray Data.
+
+         :param input_path: Path to RDF file or list of RDF files.
+         :return: Ray Dataset containing extracted documents.
+         """
+         if not ray.is_initialized():
+             ray.init()
+
+         # Ensure input_path is a list to prevent Ray from splitting string into characters
+         if isinstance(input_path, str):
+             input_path = [input_path]
+
+         # Create dataset from file paths
+         paths_ds = ray.data.from_items(input_path)
+
+         def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
+             """Process a single RDF file and return list of documents."""
+             try:
+                 file_path = row["item"]
+                 return self._parse_rdf_file(Path(file_path))
+             except Exception as e:
+                 logger.error(
+                     "Failed to process RDF file %s: %s", row.get("item", "unknown"), e
+                 )
+                 return []
+
+         # Process files in parallel and flatten results
+         docs_ds = paths_ds.flat_map(process_rdf)
+
+         # Filter valid documents
+         docs_ds = docs_ds.filter(self._should_keep_item)
+
+         return docs_ds
+
+     def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]:
+         """
+         Parse a single RDF file and extract documents.
+
+         :param file_path: Path to RDF file.
+         :return: List of document dictionaries.
+         """
+         if not file_path.is_file():
+             raise FileNotFoundError(f"RDF file not found: {file_path}")
+
          g = rdflib.Graph()
-         fmt = guess_format(file_path)
+         fmt = guess_format(str(file_path))
+
          try:
-             g.parse(file_path, format=fmt)
+             g.parse(str(file_path), format=fmt)
          except Exception as e:
              raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e

          docs: List[Dict[str, Any]] = []
-         text_col = self.text_column

+         # Process each unique subject in the RDF graph
          for subj in set(g.subjects()):
              literals = []
              props = {}
+
+             # Extract all triples for this subject
              for _, pred, obj in g.triples((subj, None, None)):
                  pred_str = str(pred)
+                 obj_str = str(obj)
+
+                 # Collect literal values as text content
                  if isinstance(obj, Literal):
-                     literals.append(str(obj))
-                 props.setdefault(pred_str, []).append(str(obj))
+                     literals.append(obj_str)
+
+                 # Store all properties (including non-literals)
+                 props.setdefault(pred_str, []).append(obj_str)

+             # Join all literal values as the text content
              text = " ".join(literals).strip()
              if not text:
-                 raise ValueError(
-                     f"Subject {subj} has no literal values; "
-                     f"missing '{text_col}' for text column."
+                 logger.warning(
+                     "Subject %s in %s has no literal values; document will have empty '%s' field.",
+                     subj,
+                     file_path,
+                     self.text_column,
                  )

-             doc = {"id": str(subj), text_col: text, "properties": props}
+             # Create document dictionary
+             doc = {
+                 "id": str(subj),
+                 self.text_column: text,
+                 "properties": props,
+                 "source_file": str(file_path),
+             }
              docs.append(doc)

          if not docs:
-             raise ValueError("RDF file contains no valid documents.")
+             logger.warning("RDF file %s contains no valid documents.", file_path)

-         return self.filter(docs)
+         return docs
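The behavioral change is easiest to see with a tiny Turtle file: a subject with no literals now yields a warning and an empty `content` field instead of raising. A sketch with a placeholder file:

```python
# Illustrative only; the Turtle snippet and path are placeholders.
from graphgen.models.reader.rdf_reader import RDFReader

ttl = """@prefix ex: <http://example.org/> .
ex:doc1 ex:title "GraphGen" ;
        ex:note  "synthesizes QA data from graphs" .
"""
with open("/tmp/sample.ttl", "w", encoding="utf-8") as f:
    f.write(ttl)

ds = RDFReader(text_column="content").read("/tmp/sample.ttl")
for doc in ds.take_all():
    # literal objects are joined (in graph iteration order) into "content"
    print(doc["id"], doc["content"], doc["source_file"])
```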
graphgen/models/reader/txt_reader.py CHANGED
@@ -1,10 +1,32 @@
- from typing import Any, Dict, List
+ from typing import List, Union
+
+ import ray
+ from ray.data import Dataset

  from graphgen.bases.base_reader import BaseReader


  class TXTReader(BaseReader):
-     def read(self, file_path: str) -> List[Dict[str, Any]]:
-         with open(file_path, "r", encoding="utf-8") as f:
-             docs = [{"type": "text", self.text_column: f.read()}]
-         return self.filter(docs)
+     def read(
+         self,
+         input_path: Union[str, List[str]],
+     ) -> Dataset:
+         """
+         Read text files from the specified input path.
+         :param input_path: Path to the input text file or list of text files.
+         :return: Ray Dataset containing the read text data.
+         """
+         docs_ds = ray.data.read_binary_files(
+             input_path,
+             include_paths=False,
+         )
+
+         docs_ds = docs_ds.map(
+             lambda row: {
+                 "type": "text",
+                 self.text_column: row["bytes"].decode("utf-8"),
+             }
+         )
+
+         docs_ds = docs_ds.filter(self._should_keep_item)
+         return docs_ds
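Each text file becomes exactly one row. A sketch with a placeholder path, assuming a no-arg constructor and the default `text_column` of "content":

```python
# Illustrative only; the path and constructor defaults are assumptions.
from graphgen.models.reader.txt_reader import TXTReader

with open("/tmp/a.txt", "w", encoding="utf-8") as f:
    f.write("hello graphgen")

ds = TXTReader().read("/tmp/a.txt")
print(ds.take_all())  # expected: [{'type': 'text', 'content': 'hello graphgen'}]
```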
graphgen/models/splitter/character_splitter.py CHANGED
@@ -17,7 +17,7 @@

      def split_text(self, text: str) -> List[str]:
          """Split incoming text and return chunks."""
-         # First we naively split the large input into a bunch of smaller ones.
+         # First we naively chunk the large input into a bunch of smaller ones.
          separator = (
              self._separator if self._is_separator_regex else re.escape(self._separator)
          )
graphgen/models/splitter/markdown_splitter.py CHANGED
@@ -6,12 +6,12 @@ from graphgen.models.splitter.recursive_character_splitter import (


  class MarkdownTextRefSplitter(RecursiveCharacterSplitter):
-     """Attempts to split the text along Markdown-formatted headings."""
+     """Attempts to chunk the text along Markdown-formatted headings."""

      def __init__(self, **kwargs: Any) -> None:
          """Initialize a MarkdownTextRefSplitter."""
          separators = [
-             # First, try to split along Markdown headings (starting with level 2)
+             # First, try to chunk along Markdown headings (starting with level 2)
              "\n#{1,6} ",
              # Note the alternative syntax for headings (below) is not handled here
              # Heading level 2
graphgen/models/splitter/recursive_character_splitter.py CHANGED
@@ -7,7 +7,7 @@ from graphgen.bases.base_splitter import BaseSplitter
  class RecursiveCharacterSplitter(BaseSplitter):
      """Splitting text by recursively look at characters.

-     Recursively tries to split by different characters to find one that works.
+     Recursively tries to chunk by different characters to find one that works.
      """

      def __init__(
@@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter):
      def _split_text_with_regex_from_end(
          self, text: str, separator: str, keep_separator: bool
      ) -> List[str]:
-         # Now that we have the separator, split the text
+         # Now that we have the separator, chunk the text
          if separator:
              if keep_separator:
                  # The parentheses in the pattern keep the delimiters in the result.
graphgen/models/storage/__init__.py CHANGED
@@ -1,3 +1,6 @@
- from .json_storage import JsonKVStorage, JsonListStorage
- from .networkx_storage import NetworkXStorage
+ from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
+ from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
+ from graphgen.models.storage.kv.json_storage import JsonKVStorage
+ from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
+
  from .rocksdb_cache import RocksDBCache
graphgen/{configs → models/storage/graph}/__init__.py RENAMED
File without changes
graphgen/models/storage/graph/kuzu_storage.py ADDED
@@ -0,0 +1,256 @@
+ import json
+ import os
+ import shutil
+ from dataclasses import dataclass
+ from typing import Any
+
+ try:
+     import kuzu
+ except ImportError:
+     kuzu = None
+
+ from graphgen.bases.base_storage import BaseGraphStorage
+
+
+ @dataclass
+ class KuzuStorage(BaseGraphStorage):
+     """
+     Graph storage implementation based on KuzuDB.
+     Since KuzuDB is a structured graph database and GraphGen uses dynamic
+     dictionaries for properties, we map the data to a generic schema:
+     - Node Table 'Entity': {id: STRING, data: STRING (JSON)}
+     - Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)}
+     """
+
+     working_dir: str = None
+     namespace: str = None
+     _db: Any = None
+     _conn: Any = None
+
+     def __post_init__(self):
+         if kuzu is None:
+             raise ImportError(
+                 "KuzuDB is not installed. Please install it via `pip install kuzu`."
+             )
+
+         self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
+         self._init_db()
+
+     def _init_db(self):
+         # KuzuDB automatically creates the directory
+         self._db = kuzu.Database(self.db_path)
+         self._conn = kuzu.Connection(self._db)
+         self._init_schema()
+         print(f"KuzuDB initialized at {self.db_path}")
+
+     def _init_schema(self):
+         """Initialize the generic Node and Edge tables if they don't exist."""
+         # Check and create Node table
+         try:
+             # We use a generic table name "Entity" to store all nodes
+             self._conn.execute(
+                 "CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
+             )
+             print("Created KuzuDB Node Table 'Entity'")
+         except RuntimeError as e:
+             # Raised when the table already exists; safe to ignore
+             print("Node Table 'Entity' already exists or error:", e)
+
+         # Check and create Edge table
+         try:
+             # We use a generic table name "Relation" to store all edges
+             self._conn.execute(
+                 "CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
+             )
+             print("Created KuzuDB Rel Table 'Relation'")
+         except RuntimeError as e:
+             print("Rel Table 'Relation' already exists or error:", e)
+
+     def index_done_callback(self):
+         """KuzuDB is ACID; changes are committed immediately, so there is nothing to flush here."""
+
+     def has_node(self, node_id: str) -> bool:
+         result = self._conn.execute(
+             "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
+         )
+         count = result.get_next()[0]
+         return count > 0
+
+     def has_edge(self, source_node_id: str, target_node_id: str):
+         result = self._conn.execute(
+             "MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
+             {"src": source_node_id, "dst": target_node_id},
+         )
+         count = result.get_next()[0]
+         return count > 0
+
+     def node_degree(self, node_id: str) -> int:
+         # Calculate total degree (incoming + outgoing)
+         query = """
+             MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
+             RETURN count(e)
+         """
+         result = self._conn.execute(query, {"id": node_id})
+         if result.has_next():
+             return result.get_next()[0]
+         return 0
+
+     def edge_degree(self, src_id: str, tgt_id: str) -> int:
+         # Kuzu supports multi-edges, so we count them rather than just
+         # checking for existence.
+         query = """
+             MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+             RETURN count(e)
+         """
+         result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
+         if result.has_next():
+             return result.get_next()[0]
+         return 0
+
+     def get_node(self, node_id: str) -> Any:
+         result = self._conn.execute(
+             "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
+         )
+         if result.has_next():
+             data_str = result.get_next()[0]
+             return json.loads(data_str) if data_str else {}
+         return None
+
+     def update_node(self, node_id: str, node_data: dict[str, str]):
+         current_data = self.get_node(node_id)
+         if current_data is None:
+             print(f"Node {node_id} not found for update.")
+             return
+
+         # Merge existing data with new data
+         current_data.update(node_data)
+         json_data = json.dumps(current_data, ensure_ascii=False)
+
+         self._conn.execute(
+             "MATCH (a:Entity {id: $id}) SET a.data = $data",
+             {"id": node_id, "data": json_data},
+         )
+
+     def get_all_nodes(self) -> Any:
+         """Returns List[Tuple[id, data_dict]]"""
+         result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
+         nodes = []
+         while result.has_next():
+             row = result.get_next()
+             nodes.append((row[0], json.loads(row[1])))
+         return nodes
+
+     def get_edge(self, source_node_id: str, target_node_id: str):
+         # Warning: if multiple edges exist, this returns the first one found
+         query = """
+             MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+             RETURN e.data
+         """
+         result = self._conn.execute(
+             query, {"src": source_node_id, "dst": target_node_id}
+         )
+         if result.has_next():
+             data_str = result.get_next()[0]
+             return json.loads(data_str) if data_str else {}
+         return None
+
+     def update_edge(
+         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+     ):
+         current_data = self.get_edge(source_node_id, target_node_id)
+         if current_data is None:
+             print(f"Edge {source_node_id}->{target_node_id} not found for update.")
+             return
+
+         current_data.update(edge_data)
+         json_data = json.dumps(current_data, ensure_ascii=False)
+
+         query = """
+             MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+             SET e.data = $data
+         """
+         self._conn.execute(
+             query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
+         )
+
+     def get_all_edges(self) -> Any:
+         """Returns List[Tuple[src, dst, data_dict]]"""
+         query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
+         result = self._conn.execute(query)
+         edges = []
+         while result.has_next():
+             row = result.get_next()
+             edges.append((row[0], row[1], json.loads(row[2])))
+         return edges
+
+     def get_node_edges(self, source_node_id: str) -> Any:
+         """Returns edges connected to this node (outgoing)."""
+         query = """
+             MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
+             RETURN a.id, b.id, e.data
+         """
+         result = self._conn.execute(query, {"src": source_node_id})
+         edges = []
+         while result.has_next():
+             row = result.get_next()
+             edges.append((row[0], row[1], json.loads(row[2])))
+         return edges
+
+     def upsert_node(self, node_id: str, node_data: dict[str, str]):
+         """
+         Insert or update a node.
+         Kuzu supports a MERGE clause (similar to Neo4j) to handle upserts.
+         """
+         json_data = json.dumps(node_data, ensure_ascii=False)
+         query = """
+             MERGE (a:Entity {id: $id})
+             ON MATCH SET a.data = $data
+             ON CREATE SET a.data = $data
+         """
+         self._conn.execute(query, {"id": node_id, "data": json_data})
+
+     def upsert_edge(
+         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+     ):
+         """
+         Insert or update an edge.
+         Note: we explicitly ensure both nodes exist before merging the edge to
+         avoid errors, although GraphGen generally creates nodes before edges.
+         """
+         # Ensure source node exists
+         if not self.has_node(source_node_id):
+             self.upsert_node(source_node_id, {})
+         # Ensure target node exists
+         if not self.has_node(target_node_id):
+             self.upsert_node(target_node_id, {})
+
+         json_data = json.dumps(edge_data, ensure_ascii=False)
+         query = """
+             MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
+             MERGE (a)-[e:Relation]->(b)
+             ON MATCH SET e.data = $data
+             ON CREATE SET e.data = $data
+         """
+         self._conn.execute(
+             query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
+         )
+
+     def delete_node(self, node_id: str):
+         # DETACH DELETE removes the node and all connected edges
+         query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
+         self._conn.execute(query, {"id": node_id})
+         print(f"Node {node_id} deleted from KuzuDB.")
+
+     def clear(self):
+         """Clear all data but keep the schema."""
+         self._conn.execute("MATCH (n) DETACH DELETE n")
+         print(f"Graph {self.namespace} cleared.")
+
+     def reload(self):
+         """For databases that need reloading; KuzuDB manages this automatically."""
+
+     def drop(self):
+         """Completely remove the database folder."""
+         if self.db_path and os.path.exists(self.db_path):
+             shutil.rmtree(self.db_path)
+             print(f"Dropped KuzuDB at {self.db_path}")
graphgen/models/storage/{networkx_storage.py → graph/networkx_storage.py} RENAMED
@@ -6,7 +6,6 @@ from typing import Any, Optional, Union, cast
  import networkx as nx

  from graphgen.bases.base_storage import BaseGraphStorage
- from graphgen.utils import logger


  @dataclass
@@ -19,11 +18,6 @@ class NetworkXStorage(BaseGraphStorage):

      @staticmethod
      def write_nx_graph(graph: nx.Graph, file_name):
-         logger.info(
-             "Writing graph with %d nodes, %d edges",
-             graph.number_of_nodes(),
-             graph.number_of_edges(),
-         )
          nx.write_graphml(graph, file_name)

      @staticmethod
@@ -82,12 +76,11 @@ class NetworkXStorage(BaseGraphStorage):
              self.working_dir, f"{self.namespace}.graphml"
          )
          preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-         if preloaded_graph is not None:
-             logger.info(
-                 "Loaded graph from %s with %d nodes, %d edges",
-                 self._graphml_xml_file,
-                 preloaded_graph.number_of_nodes(),
-                 preloaded_graph.number_of_edges(),
+         if preloaded_graph:
+             print(
+                 f"Loaded graph from {self._graphml_xml_file} with "
+                 f"{preloaded_graph.number_of_nodes()} nodes, "
+                 f"{preloaded_graph.number_of_edges()} edges"
              )
          self._graph = preloaded_graph or nx.Graph()

@@ -133,7 +126,7 @@ class NetworkXStorage(BaseGraphStorage):
          if self._graph.has_node(node_id):
              self._graph.nodes[node_id].update(node_data)
          else:
-             logger.warning("Node %s not found in the graph for update.", node_id)
+             print(f"Node {node_id} not found in the graph for update.")

      def upsert_edge(
          self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
@@ -146,10 +139,8 @@ class NetworkXStorage(BaseGraphStorage):
          if self._graph.has_edge(source_node_id, target_node_id):
              self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
          else:
-             logger.warning(
-                 "Edge %s -> %s not found in the graph for update.",
-                 source_node_id,
-                 target_node_id,
+             print(
+                 f"Edge {source_node_id} -> {target_node_id} not found in the graph for update."
              )

      def delete_node(self, node_id: str):
@@ -160,13 +151,19 @@ class NetworkXStorage(BaseGraphStorage):
          """
          if self._graph.has_node(node_id):
              self._graph.remove_node(node_id)
-             logger.info("Node %s deleted from the graph.", node_id)
+             print(f"Node {node_id} deleted from the graph.")
          else:
-             logger.warning("Node %s not found in the graph for deletion.", node_id)
+             print(f"Node {node_id} not found in the graph for deletion.")

      def clear(self):
          """
          Clear the graph by removing all nodes and edges.
          """
          self._graph.clear()
-         logger.info("Graph %s cleared.", self.namespace)
+         print(f"Graph {self.namespace} cleared.")
+
+     def reload(self):
+         """
+         Reload the graph from the GraphML file.
+         """
+         self.__post_init__()
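The new `reload()` simply re-runs `__post_init__`, so it re-reads the GraphML file from disk and discards any in-memory changes made since the last `write_nx_graph`. A minimal sketch, with the constructor fields assumed from the dataclass usage in the diff:

```python
# Illustrative only; field names follow the dataclass usage above.
from graphgen.models.storage import NetworkXStorage

g = NetworkXStorage(working_dir="./cache", namespace="demo_graph")
g.delete_node("missing")  # now prints instead of calling logger.warning
g.reload()                # re-reads ./cache/demo_graph.graphml from disk
```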
graphgen/models/storage/kv/__init__.py ADDED
File without changes