Commit 31086ae · github-actions[bot] committed · 1 parent: 10ba08f
Auto-sync from demo at Tue Dec 16 08:21:05 UTC 2025

This view is limited to 50 files because it contains too many changes.
- app.py +134 -172
- graphgen/bases/__init__.py +3 -7
- graphgen/bases/base_llm_wrapper.py +0 -6
- graphgen/bases/base_operator.py +57 -0
- graphgen/bases/base_partitioner.py +22 -27
- graphgen/bases/base_reader.py +57 -41
- graphgen/bases/base_splitter.py +3 -3
- graphgen/bases/base_storage.py +6 -17
- graphgen/bases/datatypes.py +44 -0
- graphgen/{operators/init → common}/__init__.py +1 -0
- graphgen/{operators/init → common}/init_llm.py +125 -29
- graphgen/common/init_storage.py +262 -0
- graphgen/configs/aggregated_config.yaml +0 -41
- graphgen/configs/atomic_config.yaml +0 -31
- graphgen/configs/cot_config.yaml +0 -33
- graphgen/configs/multi_hop_config.yaml +0 -34
- graphgen/configs/schema_guided_extraction_config.yaml +0 -20
- graphgen/configs/search_dna_config.yaml +0 -17
- graphgen/configs/search_protein_config.yaml +0 -15
- graphgen/configs/search_rna_config.yaml +0 -14
- graphgen/configs/vqa_config.yaml +0 -32
- graphgen/engine.py +191 -106
- graphgen/graphgen.py +0 -295
- graphgen/models/__init__.py +7 -2
- graphgen/models/extractor/schema_guided_extractor.py +3 -5
- graphgen/models/generator/vqa_generator.py +2 -2
- graphgen/models/llm/local/sglang_wrapper.py +0 -12
- graphgen/models/llm/local/vllm_wrapper.py +35 -47
- graphgen/models/partitioner/anchor_bfs_partitioner.py +9 -14
- graphgen/models/partitioner/bfs_partitioner.py +4 -9
- graphgen/models/partitioner/dfs_partitioner.py +5 -9
- graphgen/models/partitioner/ece_partitioner.py +19 -24
- graphgen/models/partitioner/leiden_partitioner.py +5 -9
- graphgen/models/reader/__init__.py +0 -1
- graphgen/models/reader/csv_reader.py +14 -11
- graphgen/models/reader/json_reader.py +41 -14
- graphgen/models/reader/jsonl_reader.py +0 -30
- graphgen/models/reader/parquet_reader.py +16 -10
- graphgen/models/reader/pdf_reader.py +35 -20
- graphgen/models/reader/pickle_reader.py +64 -16
- graphgen/models/reader/rdf_reader.py +93 -13
- graphgen/models/reader/txt_reader.py +27 -5
- graphgen/models/splitter/character_splitter.py +1 -1
- graphgen/models/splitter/markdown_splitter.py +2 -2
- graphgen/models/splitter/recursive_character_splitter.py +2 -2
- graphgen/models/storage/__init__.py +5 -2
- graphgen/{configs → models/storage/graph}/__init__.py +0 -0
- graphgen/models/storage/graph/kuzu_storage.py +256 -0
- graphgen/models/storage/{networkx_storage.py → graph/networkx_storage.py} +17 -20
- graphgen/models/storage/kv/__init__.py +0 -0
app.py CHANGED

@@ -5,14 +5,12 @@ import tempfile
 from importlib.resources import files
 
 import gradio as gr
-import …
+import ray
 from dotenv import load_dotenv
 
-from graphgen.engine import …
-from graphgen.…
-from graphgen.…
-from graphgen.models.llm.limitter import RPM, TPM
-from graphgen.utils import set_logger
+from graphgen.engine import Engine
+from graphgen.operators import operators
+from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
 from webui.base import WebuiParams
 from webui.i18n import Translate
 from webui.i18n import gettext as _

@@ -22,7 +20,6 @@ from webui.utils import cleanup_workspace, count_tokens, preview_file, setup_workspace
 root_dir = files("webui").parent
 sys.path.append(root_dir)
 
-
 load_dotenv()
 
 css = """

@@ -34,131 +31,136 @@ css = """
 """
 
 
-def …
-    # Set up working directory
-    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
-    set_logger(log_file, if_stream=True)
-    os.environ.update({k: str(v) for k, v in env.items()})
-
-    tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
-    synthesizer_llm_client = OpenAIClient(
-        model=env.get("SYNTHESIZER_MODEL", ""),
-        base_url=env.get("SYNTHESIZER_BASE_URL", ""),
-        api_key=env.get("SYNTHESIZER_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-        tokenizer=tokenizer_instance,
-    )
-    trainee_llm_client = OpenAIClient(
-        model=env.get("TRAINEE_MODEL", ""),
-        base_url=env.get("TRAINEE_BASE_URL", ""),
-        api_key=env.get("TRAINEE_API_KEY", ""),
-        request_limit=True,
-        rpm=RPM(env.get("RPM", 1000)),
-        tpm=TPM(env.get("TPM", 50000)),
-        tokenizer=tokenizer_instance,
-    )
-
-    graph_gen = GraphGen(
-        working_dir=working_dir,
-        tokenizer_instance=tokenizer_instance,
-        synthesizer_llm_client=synthesizer_llm_client,
-        trainee_llm_client=trainee_llm_client,
-    )
-
-    return graph_gen
-
-
-# pylint: disable=too-many-statements
-def run_graphgen(params: WebuiParams, progress=gr.Progress()):
-    def sum_tokens(client):
-        return sum(u["total_tokens"] for u in client.token_usage)
-
+def _get_partition_params(params: WebuiParams):
     method = params.partition_method
     if method == "dfs":
-        …
+        return {
             "max_units_per_community": params.dfs_max_units,
         }
-    …
-    …
+    if method == "bfs":
+        return {
             "max_units_per_community": params.bfs_max_units,
         }
-    …
-    …
+    if method == "leiden":
+        return {
             "max_size": params.leiden_max_size,
             "use_lcc": params.leiden_use_lcc,
             "random_seed": params.leiden_random_seed,
         }
-    … [further deleted lines are truncated in this view]
+    # ece
+    return {
+        "max_units_per_community": params.ece_max_units,
+        "min_units_per_community": params.ece_min_units,
+        "max_tokens_per_community": params.ece_max_tokens,
+        "unit_sampling": params.ece_unit_sampling,
+    }
+
+
+# pylint: disable=too-many-statements
+def run_graphgen(params: WebuiParams, progress=gr.Progress()):
+    # 1. Setup Workspace
+    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
+    driver_logger = set_logger(log_file, "GraphGen", if_stream=True)
+    CURRENT_LOGGER_VAR.set(driver_logger)
+
+    # 2. Setup Environment Variables for Ray Actors/LLM Init
+    # The refactored code relies on env vars in graphgen/common/init_llm.py
+    os.environ["SYNTHESIZER_BACKEND"] = "openai_api"  # Assuming OpenAI compatible API
+    os.environ["SYNTHESIZER_BASE_URL"] = params.synthesizer_url
+    os.environ["SYNTHESIZER_API_KEY"] = params.api_key
+    os.environ["SYNTHESIZER_MODEL"] = params.synthesizer_model
+    os.environ["RPM"] = str(params.rpm)
+    os.environ["TPM"] = str(params.tpm)
+    os.environ["TOKENIZER_MODEL"] = params.tokenizer
+
+    if params.if_trainee_model:
+        os.environ["TRAINEE_BACKEND"] = "openai_api"
+        os.environ["TRAINEE_BASE_URL"] = params.trainee_url
+        os.environ["TRAINEE_API_KEY"] = params.trainee_api_key
+        os.environ["TRAINEE_MODEL"] = params.trainee_model
 
+    # 3. Construct Pipeline Configuration (DAG)
+    nodes = [
         {
-            "…
-            "…
+            "id": "read",
+            "op_name": "read",
+            "type": "source",
+            "dependencies": [],
             "params": {
-                "…
+                "input_path": [params.upload_file],
             },
         },
         {
-            "…
-            "…
-            "…
+            "id": "chunk",
+            "op_name": "chunk",
+            "type": "map_batch",
+            "dependencies": ["read"],
+            "execution_params": {"replicas": 1},
             "params": {
                 "chunk_size": params.chunk_size,
                 "chunk_overlap": params.chunk_overlap,
             },
         },
         {
-            "…
-            "…
-            "…
+            "id": "build_kg",
+            "op_name": "build_kg",
+            "type": "map_batch",
+            "dependencies": ["chunk"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
         },
     ]
 
+    last_node_id = "build_kg"
+
+    # Optional: Quiz and Judge
     if params.if_trainee_model:
-        …
-            {
-                "name": "quiz_and_judge",
-                "deps": ["build_kg"],
-                "op_key": "quiz_and_judge",
-                "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
-            }
-        )
-        pipeline.append(
+        nodes.append(
             {
-                "…
-                "…
-                "…
+                "id": "quiz",
+                "op_name": "quiz",
+                "type": "aggregate",  # QuizService uses aggregate in config
+                "dependencies": ["build_kg"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
                 "params": {
-                    "…
-                    "…
+                    "quiz_samples": params.quiz_samples,
+                    "concurrency_limit": 200,
                 },
             }
         )
-
-        …
+
+        nodes.append(
             {
-                "…
-                "…
-                "…
-                "…
-                "method_params": partition_params,
-                },
+                "id": "judge",
+                "op_name": "judge",
+                "type": "map_batch",
+                "dependencies": ["quiz"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
             }
         )
-
+        last_node_id = "judge"
+
+    # Node: Partition
+    nodes.append(
+        {
+            "id": "partition",
+            "op_name": "partition",
+            "type": "aggregate",  # PartitionService uses aggregate
+            "dependencies": [last_node_id],
+            "params": {
+                "method": params.partition_method,
+                "method_params": _get_partition_params(params),
+            },
+        }
+    )
+
+    # Node: Generate
+    nodes.append(
         {
-            "…
-            "…
-            "…
+            "id": "generate",
+            "op_name": "generate",
+            "type": "map_batch",
+            "dependencies": ["partition"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
             "params": {
                 "method": params.mode,
                 "data_format": params.data_format,

@@ -166,88 +168,50 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()):
         }
     )
 
-    config = {
-        "if_trainee_model": params.if_trainee_model,
-        "read": {"input_file": params.upload_file},
-        "pipeline": pipeline,
-    }
+    config = {"global_params": {"working_dir": working_dir}, "nodes": nodes}
 
-    … [deleted lines are truncated in this view]
-        "TRAINEE_MODEL": params.trainee_model,
-        "SYNTHESIZER_API_KEY": params.api_key,
-        "TRAINEE_API_KEY": params.trainee_api_key,
-        "RPM": params.rpm,
-        "TPM": params.tpm,
-    }
-
-    test_api_connection(
-        env["SYNTHESIZER_BASE_URL"],
-        env["SYNTHESIZER_API_KEY"],
-        env["SYNTHESIZER_MODEL"],
-    )
-    if config["if_trainee_model"]:
-        test_api_connection(
-            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
-        )
+    try:
+        # 4. Initialize and Run Engine
+        # Initialize Ray if not already running (Engine handles this mostly, but good for safety)
+        if not ray.is_initialized():
+            ray.init(ignore_reinit_error=True, log_to_driver=True)
 
-    graph_gen.clear()
-    graph_gen.progress_bar = progress
+        engine = Engine(config, operators)
 
-    … [a block of deleted lines is truncated in this view]
-                [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)]
-            ]
-            new_df = pd.DataFrame(_update_data, columns=data_frame.columns)
-            data_frame = new_df
-
-        except Exception as e:
-            raise gr.Error(f"DataFrame operation error: {str(e)}")
-
-        return output_file, gr.DataFrame(
-            label="Token Stats",
-            headers=["Source Text Token Count", "Expected Token Usage", "Token Used"],
-            datatype="str",
-            interactive=False,
-            value=data_frame,
-            visible=True,
-            wrap=True,
-        )
+        # Start with an empty dataset to kick off the pipeline
+        ds = ray.data.from_items([])
+
+        # Execute pipeline
+        results = engine.execute(ds)
+
+        # 5. Process Output
+        # Extract the result from the 'generate' node
+        if "generate" in results:
+            result_ds = results["generate"]
+
+            # Create a temporary file to save the output
+            with tempfile.NamedTemporaryFile(
+                mode="w", suffix=".jsonl", delete=False, encoding="utf-8"
+            ) as tmpfile:
+                # Iterate over rows and write to file
+                for row in result_ds.iter_rows():
+                    json.dump(row, tmpfile, ensure_ascii=False)
+                    tmpfile.write("\n")
+                output_file = tmpfile.name
+        else:
+            raise gr.Error("Generation step failed to produce output.")
+
+        # Note: Dynamic token counting from distributed actors is not directly available
+        # via client properties in the new architecture. We return the estimated stats from input.
+
+        return output_file, params.token_counter
 
     except Exception as e:  # pylint: disable=broad-except
         raise gr.Error(f"Error occurred: {str(e)}")
 
     finally:
         # Clean up workspace
-        cleanup_workspace(
+        cleanup_workspace(working_dir)  # Optional: keep for debugging or enable
 
 
 with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:

@@ -267,7 +231,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
             ("简体中文", "zh"),
         ],
         value="en",
-        # label=_("Language"),
         render=False,
         container=False,
         elem_classes=["center-row"],

@@ -295,7 +258,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
         os.path.join(root_dir, "webui", "translation.json"),
         lang_btn,
         placeholder_langs=["en", "zh"],
-        persistant=False,
+        persistant=False,
     ):
         lang_btn.render()

@@ -701,7 +664,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
         outputs=[output, token_counter],
     )
 
-
 if __name__ == "__main__":
     demo.queue(api_open=False, default_concurrency_limit=2)
     demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
graphgen/bases/__init__.py CHANGED

@@ -2,15 +2,11 @@ from .base_extractor import BaseExtractor
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_wrapper import BaseLLMWrapper
+from .base_operator import BaseOperator
 from .base_partitioner import BasePartitioner
 from .base_reader import BaseReader
 from .base_searcher import BaseSearcher
 from .base_splitter import BaseSplitter
-from .base_storage import (
-    BaseGraphStorage,
-    BaseKVStorage,
-    BaseListStorage,
-    StorageNameSpace,
-)
+from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
 from .base_tokenizer import BaseTokenizer
-from .datatypes import Chunk, QAPair, Token
+from .datatypes import Chunk, Config, Node, QAPair, Token
graphgen/bases/base_llm_wrapper.py CHANGED

@@ -72,9 +72,3 @@ class BaseLLMWrapper(abc.ABC):
 
         filtered = filtered.strip()
         return filtered if filtered else text.strip()
-
-    def shutdown(self) -> None:
-        """Shutdown the LLM engine if applicable."""
-
-    def restart(self) -> None:
-        """Reinitialize the LLM engine if applicable."""
graphgen/bases/base_operator.py ADDED

@@ -0,0 +1,57 @@
+import inspect
+import os
+from abc import ABC, abstractmethod
+from typing import Iterable, Union
+
+import pandas as pd
+import ray
+
+from graphgen.utils import CURRENT_LOGGER_VAR, set_logger
+
+
+class BaseOperator(ABC):
+    def __init__(self, working_dir: str = "cache", op_name: str = None):
+        log_dir = os.path.join(working_dir, "logs")
+        self.op_name = op_name or self.__class__.__name__
+
+        try:
+            ctx = ray.get_runtime_context()
+            worker_id = ctx.get_actor_id() or ctx.get_worker_id()
+            worker_id_short = worker_id[-6:] if worker_id else "driver"
+        except Exception as e:
+            print(
+                "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:",
+                e,
+            )
+            worker_id_short = "local"
+
+        # e.g. cache/logs/ChunkService_a1b2c3.log
+        log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log")
+
+        self.logger = set_logger(
+            log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True
+        )
+
+        self.logger.info(
+            "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short
+        )
+
+    def __call__(
+        self, batch: pd.DataFrame
+    ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]:
+        logger_token = CURRENT_LOGGER_VAR.set(self.logger)
+        try:
+            result = self.process(batch)
+            if inspect.isgenerator(result):
+                yield from result
+            else:
+                yield result
+        finally:
+            CURRENT_LOGGER_VAR.reset(logger_token)
+
+    @abstractmethod
+    def process(self, batch):
+        raise NotImplementedError("Subclasses must implement the process method.")
+
+    def get_logger(self):
+        return self.logger
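For orientation, here is a minimal sketch of a concrete operator on top of this base class. The `UpperCaseService` name and the `content` column are hypothetical; only `BaseOperator`, its `process` contract, and the generator-style `__call__` come from the file above.

import pandas as pd

from graphgen.bases.base_operator import BaseOperator


class UpperCaseService(BaseOperator):  # hypothetical operator for illustration
    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        # self.logger is the per-worker logger set up by BaseOperator.__init__
        self.logger.info("[%s] processing %d rows", self.op_name, len(batch))
        batch["content"] = batch["content"].str.upper()
        return batch


# __call__ wraps process() in a generator, so consuming it yields batches:
op = UpperCaseService(working_dir="cache")
for out in op(pd.DataFrame({"content": ["hello"]})):
    print(out)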
graphgen/bases/base_partitioner.py CHANGED

@@ -7,7 +7,7 @@ from graphgen.bases.datatypes import Community
 
 class BasePartitioner(ABC):
     @abstractmethod
-    def …(
+    def partition(
         self,
         g: BaseGraphStorage,
         **kwargs: Any,

@@ -20,39 +20,34 @@ class BasePartitioner(ABC):
     """
 
     @staticmethod
-    def …(
-        …
-    ) -> list[
-        tuple[
-            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
-        ]
-    ]:
+    def community2batch(
+        comm: Community, g: BaseGraphStorage
+    ) -> tuple[
+        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+    ]:
         """
         Convert communities to batches of nodes and edges.
-        :param …
+        :param comm: Community
         :param g: Graph storage instance
         :return: List of batches, each batch is a tuple of (nodes, edges)
         """
-        … [deleted lines are truncated in this view]
-            if edge_data:
-                edges_data.append((…
-                edge_data = g.get_edge(v, u)
-                if edge_data:
-                    edges_data.append((v, u, edge_data))
-            batches.append((nodes_data, edges_data))
-        return batches
+        nodes = comm.nodes
+        edges = comm.edges
+        nodes_data = []
+        for node in nodes:
+            node_data = g.get_node(node)
+            if node_data:
+                nodes_data.append((node, node_data))
+        edges_data = []
+        for u, v in edges:
+            edge_data = g.get_edge(u, v)
+            if edge_data:
+                edges_data.append((u, v, edge_data))
+            else:
+                edge_data = g.get_edge(v, u)
+                if edge_data:
+                    edges_data.append((v, u, edge_data))
+        return nodes_data, edges_data
 
     @staticmethod
     def _build_adjacency_list(
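A quick sketch of the call contract, assuming a Community can be built from just its nodes/edges fields; the ToyGraph class is a hypothetical in-memory stand-in for a BaseGraphStorage, only there to show the reversed-edge fallback.

from graphgen.bases import BasePartitioner
from graphgen.bases.datatypes import Community


class ToyGraph:  # hypothetical stand-in for a BaseGraphStorage
    _nodes = {"a": {"desc": "entity a"}, "b": {"desc": "entity b"}}
    _edges = {("a", "b"): {"relation": "linked"}}

    def get_node(self, node_id):
        return self._nodes.get(node_id)

    def get_edge(self, u, v):
        return self._edges.get((u, v))


# The edge is stored as (a, b); asking for (b, a) exercises the fallback branch.
comm = Community(nodes=["a", "b"], edges=[("b", "a")])
nodes_data, edges_data = BasePartitioner.community2batch(comm, ToyGraph())
print(nodes_data)  # [("a", {...}), ("b", {...})]
print(edges_data)  # [("a", "b", {"relation": "linked"})]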
graphgen/bases/base_reader.py CHANGED

@@ -1,8 +1,10 @@
 import os
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 
+import pandas as pd
 import requests
+from ray.data import Dataset
 
 
 class BaseReader(ABC):

@@ -10,56 +12,70 @@ class BaseReader(ABC):
     Abstract base class for reading and processing data.
     """
 
-    def __init__(self, text_column: str = "content"):
+    def __init__(self, text_column: str = "content", modalities: list = None):
         self.text_column = text_column
+        self.modalities = modalities if modalities is not None else ["text"]
 
     @abstractmethod
-    def read(self, …
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
         """
         Read data from the specified file path.
 
-        :param …
-        :return: …
+        :param input_path: Path to the input file or list of file paths.
+        :return: Ray Dataset containing the read data.
         """
 
-    … [deleted helper lines are truncated in this view]
-        :return: True if the image exists, False otherwise.
-        """
-        if not path_or_url:
-            return False
-        if not path_or_url.startswith(("http://", "https://", "ftp://")):
-            path = path_or_url.replace("file://", "", 1)
-            path = os.path.abspath(path)
-            return os.path.isfile(path)
-        try:
-            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
-            return resp.status_code == 200
-        except requests.RequestException:
-            return False
+    def _should_keep_item(self, item: Dict[str, Any]) -> bool:
+        """
+        Determine whether to keep the given item based on the text column.
+
+        :param item: Dictionary representing a data entry.
+        :return: True if the item should be kept, False otherwise.
+        """
+        item_type = item.get("type")
+        assert item_type in [
+            "text",
+            "image",
+            "table",
+            "equation",
+            "protein",
+        ], f"Unsupported item type: {item_type}"
+        if item_type == "text":
+            content = item.get(self.text_column, "").strip()
+            return bool(content)
+        return True
 
+    def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame:
+        """
+        Validate data format.
+        """
+        if "type" not in batch.columns:
+            raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}")
+
+        if "text" in batch["type"].values:
+            if self.text_column not in batch.columns:
+                raise ValueError(
+                    f"Missing '{self.text_column}' column for text documents"
+                )
+
+        return batch
+
+    @staticmethod
+    def _image_exists(path_or_url: str, timeout: int = 3) -> bool:
+        """
+        Check if an image exists at the given local path or URL.
+        :param path_or_url: Local file path or remote URL of the image.
+        :param timeout: Timeout for remote URL requests in seconds.
+        :return: True if the image exists, False otherwise.
+        """
+        if not path_or_url:
+            return False
+        if not path_or_url.startswith(("http://", "https://", "ftp://")):
+            path = path_or_url.replace("file://", "", 1)
+            path = os.path.abspath(path)
+            return os.path.isfile(path)
+        try:
+            resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout)
+            return resp.status_code == 200
+        except requests.RequestException:
+            return False
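As a rough illustration of the new contract, a concrete reader might look like the sketch below. SimpleJsonlReader is hypothetical (the repo's actual readers live under graphgen/models/reader/); it only shows read() returning a Ray Dataset and _should_keep_item filtering rows, assuming each input row carries a "type" column.

from typing import List, Union

import ray
from ray.data import Dataset

from graphgen.bases.base_reader import BaseReader


class SimpleJsonlReader(BaseReader):  # hypothetical illustration
    def read(self, input_path: Union[str, List[str]]) -> Dataset:
        paths = [input_path] if isinstance(input_path, str) else input_path
        ds = ray.data.read_json(paths)  # ray.data.read_json also handles .jsonl
        # Drop rows the base class considers empty (e.g. blank text content).
        return ds.filter(self._should_keep_item)


reader = SimpleJsonlReader(text_column="content")
ds = reader.read("resources/input_examples/jsonl_demo.jsonl")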
graphgen/bases/base_splitter.py CHANGED

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from typing import Callable, Iterable, List, Literal, Optional, Union
 
 from graphgen.bases.datatypes import Chunk
-from graphgen.utils import logger
+from graphgen.utils.log import logger
 
 
 class BaseSplitter(ABC):

@@ -33,7 +33,7 @@ class BaseSplitter(ABC):
         """
         Split the input text into smaller chunks.
 
-        :param text: The input text to be …
+        :param text: The input text to be chunk.
         :return: A list of text chunks.
         """

@@ -111,7 +111,7 @@ class BaseSplitter(ABC):
 def _split_text_with_regex(
     text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
 ) -> List[str]:
-    # Now that we have the separator, …
+    # Now that we have the separator, chunk the text
     if separator:
         if keep_separator:
             # The parentheses in the pattern keep the delimiters in the result.
graphgen/bases/base_storage.py CHANGED

@@ -16,23 +16,6 @@ class StorageNameSpace:
         """commit the storage operations after querying"""
 
 
-class BaseListStorage(Generic[T], StorageNameSpace):
-    def all_items(self) -> list[T]:
-        raise NotImplementedError
-
-    def get_by_index(self, index: int) -> Union[T, None]:
-        raise NotImplementedError
-
-    def append(self, data: T):
-        raise NotImplementedError
-
-    def upsert(self, data: list[T]):
-        raise NotImplementedError
-
-    def drop(self):
-        raise NotImplementedError
-
-
 class BaseKVStorage(Generic[T], StorageNameSpace):
     def all_keys(self) -> list[str]:
         raise NotImplementedError

@@ -58,6 +41,9 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
     def drop(self):
         raise NotImplementedError
 
+    def reload(self):
+        raise NotImplementedError
+
 
 class BaseGraphStorage(StorageNameSpace):
     def has_node(self, node_id: str) -> bool:

@@ -105,3 +91,6 @@ class BaseGraphStorage(StorageNameSpace):
 
     def delete_node(self, node_id: str):
         raise NotImplementedError
+
+    def reload(self):
+        raise NotImplementedError
graphgen/bases/datatypes.py CHANGED

@@ -2,6 +2,8 @@ import math
 from dataclasses import dataclass, field
 from typing import List, Union
 
+from pydantic import BaseModel, Field, field_validator
+
 
 @dataclass
 class Chunk:

@@ -48,3 +50,45 @@ class Community:
     nodes: List[str] = field(default_factory=list)
     edges: List[tuple] = field(default_factory=list)
     metadata: dict = field(default_factory=dict)
+
+
+class Node(BaseModel):
+    id: str = Field(..., description="unique node id")
+    op_name: str = Field(..., description="operator name")
+    type: str = Field(
+        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
+    )
+    params: dict = Field(default_factory=dict, description="operator parameters")
+    dependencies: List[str] = Field(
+        default_factory=list, description="list of dependent node ids"
+    )
+    execution_params: dict = Field(
+        default_factory=dict, description="execution parameters like replicas, batch_size"
+    )
+
+    @field_validator("type")
+    @classmethod
+    def validate_type(cls, v: str) -> str:
+        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
+        if v not in valid_types:
+            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
+        return v
+
+
+class Config(BaseModel):
+    global_params: dict = Field(
+        default_factory=dict, description="global context for the computation graph"
+    )
+
+    nodes: List[Node] = Field(
+        ..., min_length=1, description="list of nodes in the computation graph"
+    )
+
+    @field_validator("nodes")
+    @classmethod
+    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
+        ids = [node.id for node in v]
+        if len(ids) != len(set(ids)):
+            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
+            raise ValueError(f"Duplicate node ids found: {duplicates}")
+        return v
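To make the schema concrete, here is a hedged sketch of building and validating a two-node pipeline config with these models; the node ids and params are illustrative.

from graphgen.bases.datatypes import Config, Node

cfg = Config(
    global_params={"working_dir": "cache"},
    nodes=[
        Node(id="chunk", op_name="chunk", type="map_batch",
             params={"chunk_size": 1024}, dependencies=[]),
        Node(id="build_kg", op_name="build_kg", type="map_batch",
             dependencies=["chunk"]),
    ],
)

# An unknown type (or a duplicate node id in Config.nodes) is rejected;
# pydantic's ValidationError subclasses ValueError:
try:
    Node(id="bad", op_name="x", type="reduce")
except ValueError as err:
    print(err)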
graphgen/{operators/init → common}/__init__.py RENAMED

@@ -1 +1,2 @@
 from .init_llm import init_llm
+from .init_storage import init_storage
graphgen/{operators/init → common}/init_llm.py RENAMED

@@ -1,56 +1,152 @@
 import os
 from typing import Any, Dict, Optional
 
+import ray
+
 from graphgen.bases import BaseLLMWrapper
+from graphgen.common.init_storage import get_actor_handle
 from graphgen.models import Tokenizer
 
 
-class …:
+class LLMServiceActor:
     """
-    A …
-    Supported backends include:
-    - http_api: HTTPClient
-    - openai_api: OpenAIClient
-    - ollama_api: OllamaClient
-    - huggingface: HuggingFaceWrapper
-    - sglang: SGLangWrapper
+    A Ray actor class to wrap LLM wrapper instances for distributed usage.
     """
 
-    … [deleted lines are truncated in this view]
-        tokenizer …
-            os.environ.get("TOKENIZER_MODEL", "cl100k_base"),
-        )
+    def __init__(self, backend: str, config: Dict[str, Any]):
+        self.backend = backend
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        tokenizer = Tokenizer(model_name=tokenizer_model)
         config["tokenizer"] = tokenizer
+
         if backend == "http_api":
             from graphgen.models.llm.api.http_client import HTTPClient
 
-            …
+            self.llm_instance = HTTPClient(**config)
+        elif backend in ("openai_api", "azure_openai_api"):
             from graphgen.models.llm.api.openai_client import OpenAIClient
+
             # pass in concrete backend to the OpenAIClient so that internally we can distinguish
             # between OpenAI and Azure OpenAI
-            …
+            self.llm_instance = OpenAIClient(**config, backend=backend)
+        elif backend == "ollama_api":
             from graphgen.models.llm.api.ollama_client import OllamaClient
 
-            …
+            self.llm_instance = OllamaClient(**config)
+        elif backend == "huggingface":
             from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper
 
-            …
+            self.llm_instance = HuggingFaceWrapper(**config)
+        elif backend == "sglang":
             from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper
 
-            …
-            # …
-            # return VLLMWrapper(**config)
+            self.llm_instance = SGLangWrapper(**config)
+
+        elif backend == "vllm":
+            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper
+
+            self.llm_instance = VLLMWrapper(**config)
+        else:
+            raise NotImplementedError(f"Backend {backend} is not implemented yet.")
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        return await self.llm_instance.generate_answer(text, history, **extra)
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_topk_per_token(text, history, **extra)
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        return await self.llm_instance.generate_inputs_prob(text, history, **extra)
+
+    def ready(self) -> bool:
+        """A simple method to check if the actor is ready."""
+        return True
+
+
+class LLMServiceProxy(BaseLLMWrapper):
+    """
+    A proxy class to interact with the LLMServiceActor for distributed LLM operations.
+    """
+
+    def __init__(self, actor_name: str):
+        super().__init__()
+        self.actor_handle = get_actor_handle(actor_name)
+        self._create_local_tokenizer()
+
+    async def generate_answer(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> str:
+        object_ref = self.actor_handle.generate_answer.remote(text, history, **extra)
+        return await object_ref
+
+    async def generate_topk_per_token(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_topk_per_token.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    async def generate_inputs_prob(
+        self, text: str, history: Optional[list[str]] = None, **extra: Any
+    ) -> list:
+        object_ref = self.actor_handle.generate_inputs_prob.remote(
+            text, history, **extra
+        )
+        return await object_ref
+
+    def _create_local_tokenizer(self):
+        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
+        self.tokenizer = Tokenizer(model_name=tokenizer_model)
+
+
+class LLMFactory:
+    """
+    A factory class to create LLM wrapper instances based on the specified backend.
+    Supported backends include:
+    - http_api: HTTPClient
+    - openai_api: OpenAIClient
+    - ollama_api: OllamaClient
+    - huggingface: HuggingFaceWrapper
+    - sglang: SGLangWrapper
+    """
+
+    @staticmethod
+    def create_llm(
+        model_type: str, backend: str, config: Dict[str, Any]
+    ) -> BaseLLMWrapper:
+        if not config:
+            raise ValueError(
+                f"No configuration provided for LLM {model_type} with backend {backend}."
+            )
+
+        actor_name = f"Actor_LLM_{model_type}"
+        try:
+            ray.get_actor(actor_name)
+        except ValueError:
+            print(f"Creating Ray actor for LLM {model_type} with backend {backend}.")
+            num_gpus = int(config.pop("num_gpus", 0))
+            actor = (
+                ray.remote(LLMServiceActor)
+                .options(
+                    name=actor_name,
+                    num_gpus=num_gpus,
+                    lifetime="detached",
+                    get_if_exists=True,
+                )
+                .remote(backend, config)
+            )
 
+            # wait for actor to be ready
+            ray.get(actor.ready.remote())
 
+        return LLMServiceProxy(actor_name)
 
 
 def _load_env_group(prefix: str) -> Dict[str, Any]:

@@ -77,5 +173,5 @@ def init_llm(model_type: str) -> Optional[BaseLLMWrapper]:
     if not config:
         return None
     backend = config.pop("backend")
-    llm_wrapper = LLMFactory.…
+    llm_wrapper = LLMFactory.create_llm(model_type, backend, config)
    return llm_wrapper
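A hedged sketch of driving this factory end to end: the endpoint, key, and model values are placeholders, and the lowercase "synthesizer" argument is an inference from how app.py sets SYNTHESIZER_* variables, since _load_env_group's exact prefix handling is not shown in this view.

import os

import ray

from graphgen.common import init_llm

ray.init(ignore_reinit_error=True)

# Placeholder credentials; init_llm reads them from the environment.
os.environ["SYNTHESIZER_BACKEND"] = "openai_api"
os.environ["SYNTHESIZER_BASE_URL"] = "https://api.openai.com/v1"
os.environ["SYNTHESIZER_API_KEY"] = "sk-..."
os.environ["SYNTHESIZER_MODEL"] = "gpt-4o-mini"

# The first call creates a detached Actor_LLM_* actor; later calls (even from
# other Ray workers) find the named actor and just return a new proxy to it.
llm = init_llm("synthesizer")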
graphgen/common/init_storage.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Union
|
| 2 |
+
|
| 3 |
+
import ray
|
| 4 |
+
|
| 5 |
+
from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class KVStorageActor:
|
| 9 |
+
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 10 |
+
if backend == "json_kv":
|
| 11 |
+
from graphgen.models import JsonKVStorage
|
| 12 |
+
|
| 13 |
+
self.kv = JsonKVStorage(working_dir, namespace)
|
| 14 |
+
elif backend == "rocksdb":
|
| 15 |
+
from graphgen.models import RocksDBKVStorage
|
| 16 |
+
|
| 17 |
+
self.kv = RocksDBKVStorage(working_dir, namespace)
|
| 18 |
+
else:
|
| 19 |
+
raise ValueError(f"Unknown KV backend: {backend}")
|
| 20 |
+
|
| 21 |
+
def data(self) -> Dict[str, Dict]:
|
| 22 |
+
return self.kv.data
|
| 23 |
+
|
| 24 |
+
def all_keys(self) -> list[str]:
|
| 25 |
+
return self.kv.all_keys()
|
| 26 |
+
|
| 27 |
+
def index_done_callback(self):
|
| 28 |
+
return self.kv.index_done_callback()
|
| 29 |
+
|
| 30 |
+
def get_by_id(self, id: str) -> Dict:
|
| 31 |
+
return self.kv.get_by_id(id)
|
| 32 |
+
|
| 33 |
+
def get_by_ids(self, ids: list[str], fields=None) -> list:
|
| 34 |
+
return self.kv.get_by_ids(ids, fields)
|
| 35 |
+
|
| 36 |
+
def get_all(self) -> Dict[str, Dict]:
|
| 37 |
+
return self.kv.get_all()
|
| 38 |
+
|
| 39 |
+
def filter_keys(self, data: list[str]) -> set[str]:
|
| 40 |
+
return self.kv.filter_keys(data)
|
| 41 |
+
|
| 42 |
+
def upsert(self, data: dict) -> dict:
|
| 43 |
+
return self.kv.upsert(data)
|
| 44 |
+
|
| 45 |
+
def drop(self):
|
| 46 |
+
return self.kv.drop()
|
| 47 |
+
|
| 48 |
+
def reload(self):
|
| 49 |
+
return self.kv.reload()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class GraphStorageActor:
|
| 53 |
+
def __init__(self, backend: str, working_dir: str, namespace: str):
|
| 54 |
+
if backend == "networkx":
|
| 55 |
+
from graphgen.models import NetworkXStorage
|
| 56 |
+
|
| 57 |
+
self.graph = NetworkXStorage(working_dir, namespace)
|
| 58 |
+
elif backend == "kuzu":
|
| 59 |
+
from graphgen.models import KuzuStorage
|
| 60 |
+
|
| 61 |
+
self.graph = KuzuStorage(working_dir, namespace)
|
| 62 |
+
else:
|
| 63 |
+
raise ValueError(f"Unknown Graph backend: {backend}")
|
| 64 |
+
|
| 65 |
+
def index_done_callback(self):
|
| 66 |
+
return self.graph.index_done_callback()
|
| 67 |
+
|
| 68 |
+
def has_node(self, node_id: str) -> bool:
|
| 69 |
+
return self.graph.has_node(node_id)
|
| 70 |
+
|
| 71 |
+
def has_edge(self, source_node_id: str, target_node_id: str):
|
| 72 |
+
return self.graph.has_edge(source_node_id, target_node_id)
|
| 73 |
+
|
| 74 |
+
def node_degree(self, node_id: str) -> int:
|
| 75 |
+
return self.graph.node_degree(node_id)
|
| 76 |
+
|
| 77 |
+
def edge_degree(self, src_id: str, tgt_id: str) -> int:
|
| 78 |
+
return self.graph.edge_degree(src_id, tgt_id)
|
| 79 |
+
|
| 80 |
+
def get_node(self, node_id: str) -> Any:
|
| 81 |
+
return self.graph.get_node(node_id)
|
| 82 |
+
|
| 83 |
+
def update_node(self, node_id: str, node_data: dict[str, str]):
|
| 84 |
+
return self.graph.update_node(node_id, node_data)
|
| 85 |
+
|
| 86 |
+
def get_all_nodes(self) -> Any:
|
| 87 |
+
return self.graph.get_all_nodes()
|
| 88 |
+
|
| 89 |
+
def get_edge(self, source_node_id: str, target_node_id: str):
|
| 90 |
+
return self.graph.get_edge(source_node_id, target_node_id)
|
| 91 |
+
|
| 92 |
+
def update_edge(
|
| 93 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 94 |
+
):
|
| 95 |
+
return self.graph.update_edge(source_node_id, target_node_id, edge_data)
|
| 96 |
+
|
| 97 |
+
def get_all_edges(self) -> Any:
|
| 98 |
+
return self.graph.get_all_edges()
|
| 99 |
+
|
| 100 |
+
def get_node_edges(self, source_node_id: str) -> Any:
|
| 101 |
+
return self.graph.get_node_edges(source_node_id)
|
| 102 |
+
|
| 103 |
+
def upsert_node(self, node_id: str, node_data: dict[str, str]):
|
| 104 |
+
return self.graph.upsert_node(node_id, node_data)
|
| 105 |
+
|
| 106 |
+
def upsert_edge(
|
| 107 |
+
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
|
| 108 |
+
):
|
| 109 |
+
return self.graph.upsert_edge(source_node_id, target_node_id, edge_data)
|
| 110 |
+
|
| 111 |
+
def delete_node(self, node_id: str):
|
| 112 |
+
return self.graph.delete_node(node_id)
|
| 113 |
+
|
| 114 |
+
def reload(self):
|
| 115 |
+
return self.graph.reload()
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_actor_handle(name: str):
|
| 119 |
+
try:
|
| 120 |
+
return ray.get_actor(name)
|
| 121 |
+
except ValueError as exc:
|
| 122 |
+
raise RuntimeError(
|
| 123 |
+
f"Actor {name} not found. Make sure it is created before accessing."
|
| 124 |
+
) from exc
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class RemoteKVStorageProxy(BaseKVStorage):
|
| 128 |
+
def __init__(self, namespace: str):
|
| 129 |
+
super().__init__()
|
| 130 |
+
self.namespace = namespace
|
| 131 |
+
self.actor_name = f"Actor_KV_{namespace}"
|
| 132 |
+
self.actor = get_actor_handle(self.actor_name)
|
| 133 |
+
|
| 134 |
+
def data(self) -> Dict[str, Any]:
|
| 135 |
+
return ray.get(self.actor.data.remote())
|
| 136 |
+
|
| 137 |
+
def all_keys(self) -> list[str]:
|
| 138 |
+
return ray.get(self.actor.all_keys.remote())
|
| 139 |
+
|
| 140 |
+
def index_done_callback(self):
|
| 141 |
+
return ray.get(self.actor.index_done_callback.remote())
|
| 142 |
+
|
| 143 |
+
def get_by_id(self, id: str) -> Union[Any, None]:
|
| 144 |
+
+        return ray.get(self.actor.get_by_id.remote(id))
+
+    def get_by_ids(self, ids: list[str], fields=None) -> list[Any]:
+        return ray.get(self.actor.get_by_ids.remote(ids, fields))
+
+    def get_all(self) -> Dict[str, Any]:
+        return ray.get(self.actor.get_all.remote())
+
+    def filter_keys(self, data: list[str]) -> set[str]:
+        return ray.get(self.actor.filter_keys.remote(data))
+
+    def upsert(self, data: Dict[str, Any]):
+        return ray.get(self.actor.upsert.remote(data))
+
+    def drop(self):
+        return ray.get(self.actor.drop.remote())
+
+    def reload(self):
+        return ray.get(self.actor.reload.remote())
+
+
+class RemoteGraphStorageProxy(BaseGraphStorage):
+    def __init__(self, namespace: str):
+        super().__init__()
+        self.namespace = namespace
+        self.actor_name = f"Actor_Graph_{namespace}"
+        self.actor = get_actor_handle(self.actor_name)
+
+    def index_done_callback(self):
+        return ray.get(self.actor.index_done_callback.remote())
+
+    def has_node(self, node_id: str) -> bool:
+        return ray.get(self.actor.has_node.remote(node_id))
+
+    def has_edge(self, source_node_id: str, target_node_id: str):
+        return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id))
+
+    def node_degree(self, node_id: str) -> int:
+        return ray.get(self.actor.node_degree.remote(node_id))
+
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        return ray.get(self.actor.edge_degree.remote(src_id, tgt_id))
+
+    def get_node(self, node_id: str) -> Any:
+        return ray.get(self.actor.get_node.remote(node_id))
+
+    def update_node(self, node_id: str, node_data: dict[str, str]):
+        return ray.get(self.actor.update_node.remote(node_id, node_data))
+
+    def get_all_nodes(self) -> Any:
+        return ray.get(self.actor.get_all_nodes.remote())
+
+    def get_edge(self, source_node_id: str, target_node_id: str):
+        return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id))
+
+    def update_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return ray.get(
+            self.actor.update_edge.remote(source_node_id, target_node_id, edge_data)
+        )
+
+    def get_all_edges(self) -> Any:
+        return ray.get(self.actor.get_all_edges.remote())
+
+    def get_node_edges(self, source_node_id: str) -> Any:
+        return ray.get(self.actor.get_node_edges.remote(source_node_id))
+
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        return ray.get(self.actor.upsert_node.remote(node_id, node_data))
+
+    def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        return ray.get(
+            self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data)
+        )
+
+    def delete_node(self, node_id: str):
+        return ray.get(self.actor.delete_node.remote(node_id))
+
+    def reload(self):
+        return ray.get(self.actor.reload.remote())
+
+
+class StorageFactory:
+    """
+    Factory class to create storage instances based on backend.
+    """
+
+    @staticmethod
+    def create_storage(backend: str, working_dir: str, namespace: str):
+        if backend in ["json_kv", "rocksdb"]:
+            actor_name = f"Actor_KV_{namespace}"
+            try:
+                ray.get_actor(actor_name)
+            except ValueError:
+                ray.remote(KVStorageActor).options(
+                    name=actor_name,
+                    lifetime="detached",
+                    get_if_exists=True,
+                ).remote(backend, working_dir, namespace)
+            return RemoteKVStorageProxy(namespace)
+        if backend in ["networkx", "kuzu"]:
+            actor_name = f"Actor_Graph_{namespace}"
+            try:
+                ray.get_actor(actor_name)
+            except ValueError:
+                ray.remote(GraphStorageActor).options(
+                    name=actor_name,
+                    lifetime="detached",
+                    get_if_exists=True,
+                ).remote(backend, working_dir, namespace)
+            return RemoteGraphStorageProxy(namespace)
+        raise ValueError(f"Unknown storage backend: {backend}")
+
+
+def init_storage(backend: str, working_dir: str, namespace: str):
+    return StorageFactory.create_storage(backend, working_dir, namespace)
graphgen/configs/aggregated_config.yaml
DELETED
@@ -1,41 +0,0 @@
-pipeline:
-  - name: read_step  # step name is unique in the pipeline, and can be referenced by other steps
-    op_key: read
-    params:
-      input_file: resources/input_examples/jsonl_demo.jsonl  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 1024  # chunk size for text splitting
-      chunk_overlap: 100  # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step]  # build_kg_step depends on chunk_step
-
-  - name: quiz_and_judge_step
-    op_key: quiz_and_judge
-    deps: [build_kg_step]  # quiz_and_judge depends on build_kg_step
-    params:
-      quiz_samples: 2  # number of quiz samples to generate
-      re_judge: false  # whether to re-judge the existing quiz samples
-
-  - name: partition_step
-    op_key: partition
-    deps: [quiz_and_judge_step]  # partition_step depends on quiz_and_judge_step
-    params:
-      method: ece  # ece is a custom partition method based on comprehension loss
-      method_params:
-        max_units_per_community: 20  # max nodes and edges per community
-        min_units_per_community: 5  # min nodes and edges per community
-        max_tokens_per_community: 10240  # max tokens per community
-        unit_sampling: max_loss  # unit sampling strategy, support: random, max_loss, min_loss
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step]  # generate_step depends on partition_step
-    params:
-      method: aggregated  # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML  # Alpaca, Sharegpt, ChatML
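All of the deleted configs in this commit share the same legacy `pipeline` schema: a list of named steps, each binding an operator via `op_key` and its upstream steps via `deps`. For reference, a minimal loader/validator for that schema, sketched here; the function name `load_pipeline` is illustrative, and the dependency check mirrors the `_validate` helper removed from graphgen/engine.py below.

# Sketch only: loads a legacy pipeline YAML and checks that every dep exists.
import yaml


def load_pipeline(path: str) -> list[dict]:
    with open(path, encoding="utf-8") as f:
        stages = yaml.safe_load(f)["pipeline"]
    names = {s["name"] for s in stages}
    for s in stages:
        for dep in s.get("deps", []):
            if dep not in names:
                raise ValueError(f"Stage {s['name']} has unknown dependency: {dep}")
    return stages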
graphgen/configs/atomic_config.yaml
DELETED
@@ -1,31 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/json_demo.json  # input file path, support json, jsonl, txt, csv, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 1024  # chunk size for text splitting
-      chunk_overlap: 100  # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step]  # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg]  # partition_step depends on build_kg
-    params:
-      method: dfs  # partition method, support: dfs, bfs, ece, leiden
-      method_params:
-        max_units_per_community: 1  # atomic partition, one node or edge per community
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step]  # generate_step depends on partition_step
-    params:
-      method: atomic  # atomic, aggregated, multi_hop, cot, vqa
-      data_format: Alpaca  # Alpaca, Sharegpt, ChatML
graphgen/configs/cot_config.yaml
DELETED
@@ -1,33 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/txt_demo.txt  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 1024  # chunk size for text splitting
-      chunk_overlap: 100  # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step]  # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step]  # partition_step depends on build_kg
-    params:
-      method: leiden  # leiden is a partitioner detection algorithm
-      method_params:
-        max_size: 20  # Maximum size of communities
-        use_lcc: false  # whether to use the largest connected component
-        random_seed: 42  # random seed for partitioning
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step]  # generate_step depends on partition_step
-    params:
-      method: cot  # atomic, aggregated, multi_hop, cot, vqa
-      data_format: Sharegpt  # Alpaca, Sharegpt, ChatML
graphgen/configs/multi_hop_config.yaml
DELETED
@@ -1,34 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/csv_demo.csv  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 1024  # chunk size for text splitting
-      chunk_overlap: 100  # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step]  # build_kg_step depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step]  # partition_step depends on build_kg_step
-    params:
-      method: ece  # ece is a custom partition method based on comprehension loss
-      method_params:
-        max_units_per_community: 3  # max nodes and edges per community, for multi-hop, we recommend setting it to 3
-        min_units_per_community: 3  # min nodes and edges per community, for multi-hop, we recommend setting it to 3
-        max_tokens_per_community: 10240  # max tokens per community
-        unit_sampling: random  # unit sampling strategy, support: random, max_loss, min_loss
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step]  # generate_step depends on partition_step
-    params:
-      method: multi_hop  # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML  # Alpaca, Sharegpt, ChatML
graphgen/configs/schema_guided_extraction_config.yaml
DELETED
@@ -1,20 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/extract_demo.txt  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 20480
-      chunk_overlap: 2000
-      separators: []
-
-  - name: extract_step
-    op_key: extract
-    deps: [chunk_step]  # extract_step depends on chunk_step
-    params:
-      method: schema_guided  # extraction method, support: schema_guided
-      schema_file: graphgen/templates/extraction/schemas/legal_contract.json  # schema file path for schema_guided method
graphgen/configs/search_dna_config.yaml
DELETED
@@ -1,17 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_dna_demo.jsonl  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step]  # search_step depends on read_step
-    params:
-      data_sources: [ncbi]  # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
-      ncbi_params:
-        email: test@example.com  # NCBI requires an email address
-        tool: GraphGen  # tool name for NCBI API
-        use_local_blast: true  # whether to use local blast for DNA search
-        local_blast_db: refseq_release/refseq_release  # path to local BLAST database (without .nhr extension)
-
graphgen/configs/search_protein_config.yaml
DELETED
@@ -1,15 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_protein_demo.jsonl  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step]  # search_step depends on read_step
-    params:
-      data_sources: [uniprot]  # data source for searcher, support: wikipedia, google, uniprot
-      uniprot_params:
-        use_local_blast: true  # whether to use local blast for uniprot search
-        local_blast_db: /your_path/2024_01/uniprot_sprot  # format: /path/to/${RELEASE}/uniprot_sprot
-        # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database)
graphgen/configs/search_rna_config.yaml
DELETED
@@ -1,14 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/search_rna_demo.jsonl  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: search_step
-    op_key: search
-    deps: [read_step]  # search_step depends on read_step
-    params:
-      data_sources: [rnacentral]  # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral
-      rnacentral_params:
-        use_local_blast: true  # whether to use local blast for RNA search
-        local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD  # path to local BLAST database (without .nhr extension)
graphgen/configs/vqa_config.yaml
DELETED
@@ -1,32 +0,0 @@
-pipeline:
-  - name: read_step
-    op_key: read
-    params:
-      input_file: resources/input_examples/vqa_demo.json  # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
-
-  - name: chunk_step
-    op_key: chunk
-    deps: [read_step]  # chunk_step depends on read_step
-    params:
-      chunk_size: 1024  # chunk size for text splitting
-      chunk_overlap: 100  # chunk overlap for text splitting
-
-  - name: build_kg_step
-    op_key: build_kg
-    deps: [chunk_step]  # build_kg depends on chunk_step
-
-  - name: partition_step
-    op_key: partition
-    deps: [build_kg_step]  # partition_step depends on build_kg_step
-    params:
-      method: anchor_bfs  # partition method
-      method_params:
-        anchor_type: image  # node type to select anchor nodes
-        max_units_per_community: 10  # atomic partition, one node or edge per community
-
-  - name: generate_step
-    op_key: generate
-    deps: [partition_step]  # generate_step depends on partition_step
-    params:
-      method: vqa  # atomic, aggregated, multi_hop, cot, vqa
-      data_format: ChatML  # Alpaca, Sharegpt, ChatML
graphgen/engine.py
CHANGED
@@ -1,125 +1,210 @@
-import threading
-from typing import Any, Callable, List
-
-
-class Context(dict):
-    _lock = threading.Lock()
-
-    def set(self, k, v):
-        with self._lock:
-            self[k] = v
-
-    def get(self, k, default=None):
-        with self._lock:
-            return super().get(k, default)
-
-
-class OpNode:
-    def __init__(
-        self,
-    ):
-
-        raise ValueError(
-            "Please check your configuration."
-        )
-
-
-def _validate(ops: List[OpNode]):
-    name_set = set()
-    for op in ops:
-        if op.name in name_set:
-            raise ValueError(f"Duplicate operation name: {op.name}")
-        name_set.add(op.name)
-    for op in ops:
-        for dep in op.deps:
-            if dep not in name_set:
-                raise ValueError(
-                    f"Operation {op.name} has unknown dependency: {dep}"
-                )
-
-    for stage in config["pipeline"]:
-        name = stage["name"]
-        method_name = stage.get("op_key")
-        method = getattr(graph_gen, method_name)
-        deps = stage.get("deps", [])
-        ops.append(op_node)
-    return ops
+import inspect
+import logging
+from collections import defaultdict, deque
+from functools import wraps
+from typing import Any, Callable, Dict, List, Set
+
+import ray
+import ray.data
+
+from graphgen.bases import Config, Node
+from graphgen.utils import logger
+
+
+class Engine:
+    def __init__(
+        self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs
+    ):
+        self.config = Config(**config)
+        self.global_params = self.config.global_params
+        self.functions = functions
+        self.datasets: Dict[str, ray.data.Dataset] = {}
+
+        if not ray.is_initialized():
+            context = ray.init(
+                ignore_reinit_error=True,
+                logging_level=logging.ERROR,
+                log_to_driver=True,
+                **ray_init_kwargs,
+            )
+            logger.info("Ray Dashboard URL: %s", context.dashboard_url)
+
+    @staticmethod
+    def _topo_sort(nodes: List[Node]) -> List[Node]:
+        id_to_node: Dict[str, Node] = {}
+        for n in nodes:
+            id_to_node[n.id] = n
+
+        indeg: Dict[str, int] = {nid: 0 for nid in id_to_node}
+        adj: Dict[str, List[str]] = defaultdict(list)
+
+        for n in nodes:
+            nid = n.id
+            deps: List[str] = n.dependencies
+            uniq_deps: Set[str] = set(deps)
+            for d in uniq_deps:
+                if d not in id_to_node:
+                    raise ValueError(
+                        f"The dependency node id {d} of node {nid} is not defined in the configuration."
+                    )
+                indeg[nid] += 1
+                adj[d].append(nid)
+
+        zero_deg: deque = deque(
+            [id_to_node[nid] for nid, deg in indeg.items() if deg == 0]
+        )
+        sorted_nodes: List[Node] = []
+
+        while zero_deg:
+            cur = zero_deg.popleft()
+            sorted_nodes.append(cur)
+            cur_id = cur.id
+            for nb_id in adj.get(cur_id, []):
+                indeg[nb_id] -= 1
+                if indeg[nb_id] == 0:
+                    zero_deg.append(id_to_node[nb_id])
+
+        if len(sorted_nodes) != len(nodes):
+            remaining = [nid for nid, deg in indeg.items() if deg > 0]
+            raise ValueError(
+                f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}"
+            )
+
+        return sorted_nodes
+
+    def _get_input_dataset(
+        self, node: Node, initial_ds: ray.data.Dataset
+    ) -> ray.data.Dataset:
+        deps = node.dependencies
+
+        if not deps:
+            return initial_ds
+
+        if len(deps) == 1:
+            return self.datasets[deps[0]]
+
+        main_ds = self.datasets[deps[0]]
+        other_dss = [self.datasets[d] for d in deps[1:]]
+        return main_ds.union(*other_dss)
+
+    def _execute_node(self, node: Node, initial_ds: ray.data.Dataset):
+        def _filter_kwargs(
+            func_or_class: Callable,
+            global_params: Dict[str, Any],
+            func_params: Dict[str, Any],
+        ) -> Dict[str, Any]:
+            """
+            1. global_params: only when specified in function signature, will be passed
+            2. func_params: pass specified params first, then **kwargs if exists
+            """
+            try:
+                sig = inspect.signature(func_or_class)
+            except ValueError:
+                return {}
+
+            params = sig.parameters
+            final_kwargs = {}
+
+            has_var_keywords = any(
+                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
+            )
+            valid_keys = set(params.keys())
+            for k, v in global_params.items():
+                if k in valid_keys:
+                    final_kwargs[k] = v
+
+            for k, v in func_params.items():
+                if k in valid_keys or has_var_keywords:
+                    final_kwargs[k] = v
+            return final_kwargs
+
+        if node.op_name not in self.functions:
+            raise ValueError(f"Operator {node.op_name} not found for node {node.id}")
+
+        op_handler = self.functions[node.op_name]
+        node_params = _filter_kwargs(op_handler, self.global_params, node.params or {})
+
+        if node.type == "source":
+            self.datasets[node.id] = op_handler(**node_params)
+            return
+
+        input_ds = self._get_input_dataset(node, initial_ds)
+
+        if inspect.isclass(op_handler):
+            execution_params = node.execution_params or {}
+            replicas = execution_params.get("replicas", 1)
+            batch_size = (
+                int(execution_params.get("batch_size"))
+                if "batch_size" in execution_params
+                else "default"
+            )
+            compute_resources = execution_params.get("compute_resources", {})
+
+            if node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                    batch_size=None,  # aggregate processes the whole dataset at once
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
+            else:
+                # others like map, filter, flatmap, map_batch let actors process data inside batches
+                self.datasets[node.id] = input_ds.map_batches(
+                    op_handler,
+                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+                    batch_size=batch_size,
+                    num_gpus=compute_resources.get("num_gpus", 0)
+                    if compute_resources
+                    else 0,
+                    fn_constructor_kwargs=node_params,
+                    batch_format="pandas",
+                )
+
+        else:
+
+            @wraps(op_handler)
+            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
+                return op_handler(row_or_batch, **node_params)
+
+            if node.type == "map":
+                self.datasets[node.id] = input_ds.map(func_wrapper)
+            elif node.type == "filter":
+                self.datasets[node.id] = input_ds.filter(func_wrapper)
+            elif node.type == "flatmap":
+                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
+            elif node.type == "aggregate":
+                self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                    func_wrapper, batch_format="default"
+                )
+            elif node.type == "map_batch":
+                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
+            else:
+                raise ValueError(
+                    f"Unsupported node type {node.type} for node {node.id}"
+                )
+
+    @staticmethod
+    def _find_leaf_nodes(nodes: List[Node]) -> Set[str]:
+        all_ids = {n.id for n in nodes}
+        deps_set = set()
+        for n in nodes:
+            deps_set.update(n.dependencies)
+        return all_ids - deps_set
+
+    def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]:
+        sorted_nodes = self._topo_sort(self.config.nodes)
+
+        for node in sorted_nodes:
+            self._execute_node(node, initial_ds)
+
+        leaf_nodes = self._find_leaf_nodes(sorted_nodes)
+
+        @ray.remote
+        def _fetch_result(ds: ray.data.Dataset) -> List[Any]:
+            return ds.take_all()
+
+        return {node_id: self.datasets[node_id] for node_id in leaf_nodes}
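A sketch of driving the new Ray-based Engine end to end. The `upper` operator and the node dict are invented for illustration, and it assumes `Config` accepts plain dicts for its `nodes` entries; the field names (`id`, `op_name`, `type`, `dependencies`, `params`) are those the code above reads.

# Sketch only: one "map" node applied to an in-memory dataset.
import ray.data

from graphgen.engine import Engine


def upper(row, **_):  # toy operator, not part of the repo
    row["text"] = row["text"].upper()
    return row


config = {
    "global_params": {},
    "nodes": [
        {"id": "upper_step", "op_name": "upper", "type": "map",
         "dependencies": [], "params": {}},
    ],
}

engine = Engine(config=config, functions={"upper": upper})
ds = ray.data.from_items([{"text": "hello"}, {"text": "world"}])
leaves = engine.execute(ds)  # leaf node ids -> lazy Ray datasets
print(leaves["upper_step"].take_all())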
graphgen/graphgen.py
DELETED
@@ -1,295 +0,0 @@
-import os
-import time
-from typing import Dict
-
-import gradio as gr
-
-from graphgen.bases import BaseLLMWrapper
-from graphgen.bases.datatypes import Chunk
-from graphgen.models import (
-    JsonKVStorage,
-    JsonListStorage,
-    NetworkXStorage,
-    OpenAIClient,
-    Tokenizer,
-)
-from graphgen.operators import (
-    build_kg,
-    chunk_documents,
-    extract_info,
-    generate_qas,
-    init_llm,
-    judge_statement,
-    partition_kg,
-    quiz,
-    read_files,
-    search_all,
-)
-from graphgen.utils import async_to_sync_method, compute_mm_hash, logger
-
-sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-
-
-class GraphGen:
-    def __init__(
-        self,
-        unique_id: int = int(time.time()),
-        working_dir: str = os.path.join(sys_path, "cache"),
-        tokenizer_instance: Tokenizer = None,
-        synthesizer_llm_client: OpenAIClient = None,
-        trainee_llm_client: OpenAIClient = None,
-        progress_bar: gr.Progress = None,
-    ):
-        self.unique_id: int = unique_id
-        self.working_dir: str = working_dir
-
-        # llm
-        self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
-            model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
-        )
-
-        self.synthesizer_llm_client: BaseLLMWrapper = (
-            synthesizer_llm_client or init_llm("synthesizer")
-        )
-        self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
-
-        self.full_docs_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="full_docs"
-        )
-        self.chunks_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="chunks"
-        )
-        self.graph_storage: NetworkXStorage = NetworkXStorage(
-            self.working_dir, namespace="graph"
-        )
-        self.rephrase_storage: JsonKVStorage = JsonKVStorage(
-            self.working_dir, namespace="rephrase"
-        )
-        self.partition_storage: JsonListStorage = JsonListStorage(
-            self.working_dir, namespace="partition"
-        )
-        self.search_storage: JsonKVStorage = JsonKVStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="search",
-        )
-        self.qa_storage: JsonListStorage = JsonListStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="qa",
-        )
-        self.extract_storage: JsonKVStorage = JsonKVStorage(
-            os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
-            namespace="extraction",
-        )
-
-        # webui
-        self.progress_bar: gr.Progress = progress_bar
-
-    @async_to_sync_method
-    async def read(self, read_config: Dict):
-        """
-        read files from input sources
-        """
-        doc_stream = read_files(**read_config, cache_dir=self.working_dir)
-
-        batch = {}
-        for doc in doc_stream:
-            doc_id = compute_mm_hash(doc, prefix="doc-")
-            batch[doc_id] = doc
-
-        # TODO: configurable whether to use coreference resolution
-
-        _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
-        new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
-        if len(new_docs) == 0:
-            logger.warning("All documents are already in the storage")
-            return
-        self.full_docs_storage.upsert(new_docs)
-        self.full_docs_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def chunk(self, chunk_config: Dict):
-        """
-        chunk documents into smaller pieces from full_docs_storage if not already present
-        """
-        new_docs = self.full_docs_storage.get_all()
-        if len(new_docs) == 0:
-            logger.warning("All documents are already in the storage")
-            return
-
-        inserting_chunks = await chunk_documents(
-            new_docs,
-            self.tokenizer_instance,
-            self.progress_bar,
-            **chunk_config,
-        )
-
-        _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
-        inserting_chunks = {
-            k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
-        }
-
-        if len(inserting_chunks) == 0:
-            logger.warning("All chunks are already in the storage")
-            return
-
-        self.chunks_storage.upsert(inserting_chunks)
-        self.chunks_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def build_kg(self):
-        """
-        build knowledge graph from text chunks
-        """
-        # Step 1: get new chunks
-        inserting_chunks = self.chunks_storage.get_all()
-
-        if len(inserting_chunks) == 0:
-            logger.warning("All chunks are already in the storage")
-            return
-
-        logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
-        # Step 2: build knowledge graph from new chunks
-        _add_entities_and_relations = await build_kg(
-            llm_client=self.synthesizer_llm_client,
-            kg_instance=self.graph_storage,
-            chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
-            progress_bar=self.progress_bar,
-        )
-        if not _add_entities_and_relations:
-            logger.warning("No entities or relations extracted from text chunks")
-            return
-
-        # Step 3: upsert new entities and relations to the graph storage
-        self.graph_storage.index_done_callback()
-
-        return _add_entities_and_relations
-
-    @async_to_sync_method
-    async def search(self, search_config: Dict):
-        logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
-
-        seeds = self.full_docs_storage.get_all()
-        if len(seeds) == 0:
-            logger.warning("All documents are already been searched")
-            return
-        search_results = await search_all(
-            seed_data=seeds,
-            search_config=search_config,
-        )
-
-        _add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
-        search_results = {
-            k: v for k, v in search_results.items() if k in _add_search_keys
-        }
-        if len(search_results) == 0:
-            logger.warning("All search results are already in the storage")
-            return
-        self.search_storage.upsert(search_results)
-        self.search_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def quiz_and_judge(self, quiz_and_judge_config: Dict):
-        logger.warning(
-            "Quiz and Judge operation needs trainee LLM client."
-            " Make sure to provide one."
-        )
-        max_samples = quiz_and_judge_config["quiz_samples"]
-        await quiz(
-            self.synthesizer_llm_client,
-            self.graph_storage,
-            self.rephrase_storage,
-            max_samples,
-            progress_bar=self.progress_bar,
-        )
-
-        # TODO: assert trainee_llm_client is valid before judge
-        if not self.trainee_llm_client:
-            # TODO: shutdown existing synthesizer_llm_client properly
-            logger.info("No trainee LLM client provided, initializing a new one.")
-            self.synthesizer_llm_client.shutdown()
-            self.trainee_llm_client = init_llm("trainee")
-
-        re_judge = quiz_and_judge_config["re_judge"]
-        _update_relations = await judge_statement(
-            self.trainee_llm_client,
-            self.graph_storage,
-            self.rephrase_storage,
-            re_judge,
-            progress_bar=self.progress_bar,
-        )
-
-        self.rephrase_storage.index_done_callback()
-        _update_relations.index_done_callback()
-
-        logger.info("Shutting down trainee LLM client.")
-        self.trainee_llm_client.shutdown()
-        self.trainee_llm_client = None
-        logger.info("Restarting synthesizer LLM client.")
-        self.synthesizer_llm_client.restart()
-
-    @async_to_sync_method
-    async def partition(self, partition_config: Dict):
-        batches = await partition_kg(
-            self.graph_storage,
-            self.chunks_storage,
-            self.tokenizer_instance,
-            partition_config,
-        )
-        self.partition_storage.upsert(batches)
-        return batches
-
-    @async_to_sync_method
-    async def extract(self, extract_config: Dict):
-        logger.info("Extracting information from given chunks...")
-
-        results = await extract_info(
-            self.synthesizer_llm_client,
-            self.chunks_storage,
-            extract_config,
-            progress_bar=self.progress_bar,
-        )
-        if not results:
-            logger.warning("No information extracted")
-            return
-
-        self.extract_storage.upsert(results)
-        self.extract_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def generate(self, generate_config: Dict):
-
-        batches = self.partition_storage.data
-        if not batches:
-            logger.warning("No partitions found for QA generation")
-            return
-
-        # Step 2: generate QA pairs
-        results = await generate_qas(
-            self.synthesizer_llm_client,
-            batches,
-            generate_config,
-            progress_bar=self.progress_bar,
-        )
-
-        if not results:
-            logger.warning("No QA pairs generated")
-            return
-
-        # Step 3: store the generated QA pairs
-        self.qa_storage.upsert(results)
-        self.qa_storage.index_done_callback()
-
-    @async_to_sync_method
-    async def clear(self):
-        self.full_docs_storage.drop()
-        self.chunks_storage.drop()
-        self.search_storage.drop()
-        self.graph_storage.clear()
-        self.rephrase_storage.drop()
-        self.qa_storage.drop()
-
-        logger.info("All caches are cleared")
-
-    # TODO: add data filtering step here in the future
-    # graph_gen.filter(filter_config=config["filter"])
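For reference, the deleted facade above was driven stage by stage from synchronous code (the async methods are sync-wrapped by async_to_sync_method). A sketch of that pre-commit call pattern, with config dicts abbreviated to the keys the deleted pipeline configs used:

# Sketch only: the legacy GraphGen workflow removed in this commit.
from graphgen.graphgen import GraphGen

gg = GraphGen(working_dir="./cache")
gg.read({"input_file": "resources/input_examples/json_demo.json"})
gg.chunk({"chunk_size": 1024, "chunk_overlap": 100})
gg.build_kg()
gg.partition({"method": "dfs", "method_params": {"max_units_per_community": 1}})
gg.generate({"method": "atomic", "data_format": "Alpaca"})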
graphgen/models/__init__.py
CHANGED
@@ -18,7 +18,6 @@ from .partitioner import (
 )
 from .reader import (
     CSVReader,
-    JSONLReader,
     JSONReader,
     ParquetReader,
     PDFReader,
@@ -33,5 +32,11 @@ from .searcher.kg.wiki_search import WikiSearch
 from .searcher.web.bing_search import BingSearch
 from .searcher.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage import
+from .storage import (
+    JsonKVStorage,
+    KuzuStorage,
+    NetworkXStorage,
+    RocksDBCache,
+    RocksDBKVStorage,
+)
 from .tokenizer import Tokenizer
graphgen/models/extractor/schema_guided_extractor.py
CHANGED
@@ -60,8 +60,8 @@ class SchemaGuidedExtractor(BaseExtractor):
         return prompt
 
     async def extract(self, chunk: dict) -> dict:
-        _chunk_id =
-        text = chunk
+        _chunk_id = chunk.get("_chunk_id", "")
+        text = chunk.get("content", "")
 
         prompt = self.build_prompt(text)
         response = await self.llm_client.generate_answer(prompt)
@@ -88,9 +88,7 @@ class SchemaGuidedExtractor(BaseExtractor):
             return {}
 
     @staticmethod
-    def merge_extractions(
-        extraction_list: List[Dict[str, dict]]
-    ) -> Dict[str, dict]:
+    def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]:
        """
         Merge multiple extraction results based on their hashes.
         :param extraction_list: List of extraction results, each is a dict with hash as key and record as value.
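A usage sketch for the static merger above. The hash keys and record values are invented, and last-writer-wins, dict-style merge semantics are an assumption, since the method body lies outside this hunk.

# Sketch only: merging per-chunk extraction results keyed by content hash.
from graphgen.models.extractor.schema_guided_extractor import SchemaGuidedExtractor

batch_results = [
    {"hash-a": {"party": "Acme Corp"}},   # result from chunk 1 (values invented)
    {"hash-b": {"party": "Globex LLC"}},  # result from chunk 2
]
merged = SchemaGuidedExtractor.merge_extractions(batch_results)
print(sorted(merged))  # ["hash-a", "hash-b"]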
graphgen/models/generator/vqa_generator.py
CHANGED
@@ -77,8 +77,8 @@ class VQAGenerator(BaseGenerator):
         nodes, _ = batch
         for node in nodes:
             node_data = node[1]
-            if "
-                img_path = node_data["
+            if "image_data" in node_data and node_data["image_data"]:
+                img_path = node_data["image_data"]["img_path"]
                 for qa in qa_pairs.values():
                     qa["img_path"] = img_path
                 result.update(qa_pairs)
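The node payload shape implied by this fix, sketched with invented values:

# Sketch only: an image-anchored node tuple as the generator now expects it.
node = (
    "img_node_1",  # node id; node_data = node[1]
    {"image_data": {"img_path": "resources/images/demo.png"}},
)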
graphgen/models/llm/local/sglang_wrapper.py
CHANGED
@@ -138,15 +138,3 @@ class SGLangWrapper(BaseLLMWrapper):
         raise NotImplementedError(
             "SGLangWrapper does not support per-token logprobs yet."
         )
-
-    def shutdown(self) -> None:
-        """Gracefully shutdown the SGLang engine."""
-        if hasattr(self, "engine"):
-            self.engine.shutdown()
-
-    def restart(self) -> None:
-        """Restart the SGLang engine."""
-        self.shutdown()
-        self.engine = self.engine.__class__(
-            model_path=self.model_path, tp_size=self.tp_size
-        )
graphgen/models/llm/local/vllm_wrapper.py
CHANGED
@@ -1,3 +1,5 @@
+import math
+import uuid
 from typing import Any, List, Optional
 
 from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
@@ -6,7 +8,7 @@ from graphgen.bases.datatypes import Token
 
 class VLLMWrapper(BaseLLMWrapper):
     """
-    Async inference backend based on vLLM
+    Async inference backend based on vLLM.
     """
 
     def __init__(
@@ -20,12 +22,11 @@ class VLLMWrapper(BaseLLMWrapper):
         **kwargs: Any,
     ):
         super().__init__(temperature=temperature, top_p=top_p, **kwargs)
-
         try:
             from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
         except ImportError as exc:
             raise ImportError(
-                "VLLMWrapper requires vllm. Install it with:
+                "VLLMWrapper requires vllm. Install it with: uv pip install vllm"
             ) from exc
 
         self.SamplingParams = SamplingParams
@@ -35,9 +36,9 @@ class VLLMWrapper(BaseLLMWrapper):
             tensor_parallel_size=tensor_parallel_size,
             gpu_memory_utilization=gpu_memory_utilization,
             trust_remote_code=kwargs.get("trust_remote_code", True),
+            disable_log_stats=False,
         )
         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
-
         self.temperature = temperature
         self.top_p = top_p
         self.topk = topk
@@ -60,6 +61,7 @@ class VLLMWrapper(BaseLLMWrapper):
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> str:
         full_prompt = self._build_inputs(text, history)
+        request_id = f"graphgen_req_{uuid.uuid4()}"
 
         sp = self.SamplingParams(
             temperature=self.temperature if self.temperature > 0 else 1.0,
@@ -67,71 +69,57 @@ class VLLMWrapper(BaseLLMWrapper):
             max_tokens=extra.get("max_new_tokens", 512),
         )
 
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if not final_output or not final_output.outputs:
+            return ""
+
+        return final_output.outputs[0].text
 
     async def generate_topk_per_token(
         self, text: str, history: Optional[List[str]] = None, **extra: Any
     ) -> List[Token]:
         full_prompt = self._build_inputs(text, history)
 
+        request_id = f"graphgen_topk_{uuid.uuid4()}"
+
         sp = self.SamplingParams(
             temperature=0,
             max_tokens=1,
             logprobs=self.topk,
         )
 
+        result_generator = self.engine.generate(full_prompt, sp, request_id=request_id)
+
+        final_output = None
+        async for request_output in result_generator:
+            final_output = request_output
+
+        if (
+            not final_output
+            or not final_output.outputs
+            or not final_output.outputs[0].logprobs
+        ):
+            return []
+
+        top_logprobs = final_output.outputs[0].logprobs[0]
 
         tokens = []
         for _, logprob_obj in top_logprobs.items():
             tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob
+            prob = float(math.exp(logprob_obj.logprob))
             tokens.append(Token(tok_str, prob))
+
         tokens.sort(key=lambda x: -x.prob)
         return tokens
 
     async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
-        full_prompt = self._build_inputs(text, history)
-
-        # vLLM has no ready-made "mask one token, then compute its prob" API,
-        # so we take the most direct route: feed the whole prompt in at once,
-        # enable prompt_logprobs so vLLM returns a logprob for every position
-        # of the *input* part, then pick out each token's probability.
-        sp = self.SamplingParams(
-            temperature=0,
-            max_tokens=0,  # generate no new tokens
-            prompt_logprobs=1,  # top-1 is enough
-        )
-
-        results = []
-        async for req_output in self.engine.generate(
-            full_prompt, sp, request_id="graphgen_prob"
-        ):
-            results = req_output.outputs
-
-        # prompt_logprobs is a list whose length equals the number of prompt
-        # tokens; each element is a dict {token_id: logprob_obj} or None
-        # (the first position is None)
-        prompt_logprobs = results[-1].prompt_logprobs
-
-        tokens = []
-        for _, logprob_dict in enumerate(prompt_logprobs):
-            if logprob_dict is None:
-                continue
-            # each dict holds exactly one kv pair here, because of top-1
-            _, logprob_obj = next(iter(logprob_dict.items()))
-            tok_str = logprob_obj.decoded_token
-            prob = float(logprob_obj.logprob.exp())
-            tokens.append(Token(tok_str, prob))
-        return tokens
+        raise NotImplementedError(
+            "VLLMWrapper does not support per-token logprobs yet."
+        )
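A sketch of calling the wrapper from synchronous code. The model name is a placeholder, constructor arguments outside this hunk are assumed, and a GPU with the model weights is required; the string-returning method is assumed to be named generate_answer, matching the call sites elsewhere in the repo (e.g. the extractor above).

# Sketch only: driving the async vLLM wrapper with asyncio.
import asyncio

from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper


async def main() -> None:
    llm = VLLMWrapper(model="Qwen/Qwen2.5-7B-Instruct")  # constructor args assumed
    print(await llm.generate_answer("What is a knowledge graph?"))

    top = await llm.generate_topk_per_token("Paris is the capital of")
    print(top)  # list[Token], sorted by descending probability


asyncio.run(main())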
graphgen/models/partitioner/anchor_bfs_partitioner.py
CHANGED
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List, Literal, Set, Tuple
+from typing import Any, Iterable, List, Literal, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -30,24 +30,23 @@ class AnchorBFSPartitioner(BFSPartitioner):
         self.anchor_type = anchor_type
         self.anchor_ids = anchor_ids
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()  # List[tuple[id, meta]]
         edges = g.get_all_edges()  # List[tuple[u, v, meta]]
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
-        anchors: Set[str] = await self._pick_anchor_ids(nodes)
+        anchors: Set[str] = self._pick_anchor_ids(nodes)
         if not anchors:
-            return
+            return  # if no anchors, return nothing
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         seeds = list(anchors)
         random.shuffle(seeds)
@@ -55,17 +54,13 @@ class AnchorBFSPartitioner(BFSPartitioner):
         for seed_node in seeds:
             if seed_node in used_n:
                 continue
-            comm_n, comm_e = await self._grow_community(
+            comm_n, comm_e = self._grow_community(
                 seed_node, adj, max_units_per_community, used_n, used_e
             )
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
+                yield Community(id=seed_node, nodes=comm_n, edges=comm_e)
 
-    async def _pick_anchor_ids(
+    def _pick_anchor_ids(
         self,
         nodes: List[tuple[str, dict]],
     ) -> Set[str]:
@@ -80,7 +75,7 @@ class AnchorBFSPartitioner(BFSPartitioner):
         return anchor_ids
 
     @staticmethod
-    async def _grow_community(
+    def _grow_community(
         seed: str,
         adj: dict[str, List[str]],
         max_units: int,
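With partition now a generator, callers iterate communities lazily instead of receiving a fully materialized list. A consumption sketch; constructor arguments for both classes are assumed, and the anchor_type value follows the deleted vqa config above.

# Sketch only: communities stream out one at a time, so large graphs are
# no longer accumulated in memory before generation starts.
from graphgen.models import NetworkXStorage
from graphgen.models.partitioner.anchor_bfs_partitioner import AnchorBFSPartitioner

g = NetworkXStorage("./cache", namespace="graph")
partitioner = AnchorBFSPartitioner(anchor_type="image")

for community in partitioner.partition(g, max_units_per_community=10):
    print(community.id, len(community.nodes), len(community.edges))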
graphgen/models/partitioner/bfs_partitioner.py
CHANGED
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List
+from typing import Any, Iterable, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -17,12 +17,12 @@ class BFSPartitioner(BasePartitioner):
     (A unit is a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
@@ -30,7 +30,6 @@ class BFSPartitioner(BasePartitioner):
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -74,8 +73,4 @@ class BFSPartitioner(BasePartitioner):
             queue.append((NODE_UNIT, n))
 
         if comm_n or comm_e:
-            communities.append(
-                Community(id=len(communities), nodes=comm_n, edges=comm_e)
-            )
-
-        return communities
+            yield Community(id=seed, nodes=comm_n, edges=comm_e)
graphgen/models/partitioner/dfs_partitioner.py
CHANGED
@@ -1,5 +1,6 @@
 import random
-from typing import Any, List
+from collections.abc import Iterable
+from typing import Any
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -16,12 +17,12 @@ class DFSPartitioner(BasePartitioner):
     (In GraphGen, a unit is defined as a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
@@ -29,7 +30,6 @@ class DFSPartitioner(BasePartitioner):
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -71,8 +71,4 @@ class DFSPartitioner(BasePartitioner):
             stack.append((NODE_UNIT, n))
 
         if comm_n or comm_e:
-            communities.append(
-                Community(id=len(communities), nodes=comm_n, edges=comm_e)
-            )
-
-        return communities
+            yield Community(id=seed, nodes=comm_n, edges=comm_e)
graphgen/models/partitioner/ece_partitioner.py
CHANGED

@@ -1,8 +1,8 @@
-import asyncio
 import random
-from typing import Any, Dict, List, Optional, Set, Tuple
+from collections import deque
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

-from tqdm.asyncio import tqdm
+from tqdm import tqdm

 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -51,7 +51,7 @@ class ECEPartitioner(BFSPartitioner):
             raise ValueError(f"Invalid edge sampling: {edge_sampling}")
         return units

-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 10,
@@ -59,7 +59,7 @@ class ECEPartitioner(BFSPartitioner):
         max_tokens_per_community: int = 10240,
         unit_sampling: str = "random",
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes: List[Tuple[str, dict]] = g.get_all_nodes()
         edges: List[Tuple[str, str, dict]] = g.get_all_edges()

@@ -73,21 +73,18 @@ class ECEPartitioner(BFSPartitioner):

         used_n: Set[str] = set()
         used_e: Set[frozenset[str]] = set()
-        communities: List = []

         all_units = self._sort_units(all_units, unit_sampling)

-        async def _grow_community(
-            seed_unit: Tuple[str, Any, dict]
-        ) -> Optional[Community]:
+        def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]:
             nonlocal used_n, used_e

             community_nodes: Dict[str, dict] = {}
             community_edges: Dict[frozenset[str], dict] = {}
-            queue = asyncio.Queue()
+            queue = deque()
             token_sum = 0

-            async def _add_unit(u):
+            def _add_unit(u):
                 nonlocal token_sum
                 t, i, d = u
                 if t == NODE_UNIT:  # node
@@ -103,11 +100,11 @@ class ECEPartitioner(BFSPartitioner):
                     token_sum += d.get("length", 0)
                 return True

-            await _add_unit(seed_unit)
-            await queue.put(seed_unit)
+            _add_unit(seed_unit)
+            queue.append(seed_unit)

             # BFS
-            while not queue.empty():
+            while queue:
                 if (
                     len(community_nodes) + len(community_edges)
                     >= max_units_per_community
@@ -115,7 +112,7 @@ class ECEPartitioner(BFSPartitioner):
                 ):
                     break

-                cur_type, cur_id, _ = await queue.get()
+                cur_type, cur_id, _ = queue.popleft()

                 neighbors: List[Tuple[str, Any, dict]] = []
                 if cur_type == NODE_UNIT:
@@ -136,26 +133,24 @@ class ECEPartitioner(BFSPartitioner):
                         or token_sum >= max_tokens_per_community
                     ):
                         break
-                    if await _add_unit(nb):
-                        await queue.put(nb)
+                    if _add_unit(nb):
+                        queue.append(nb)

             if len(community_nodes) + len(community_edges) < min_units_per_community:
                 return None

             return Community(
-                id=len(communities),
+                id=seed_unit[1],
                 nodes=list(community_nodes.keys()),
                 edges=[(u, v) for (u, v), _ in community_edges.items()],
             )

-        async for unit in tqdm(all_units, desc="ECE partition"):
+        for unit in tqdm(all_units, desc="ECE partition"):
             utype, uid, _ = unit
             if (utype == NODE_UNIT and uid in used_n) or (
                 utype == EDGE_UNIT and uid in used_e
             ):
                 continue
-            comm = await _grow_community(unit)
-            if comm is not None:
-                communities.append(comm)
-
-        return communities
+            comm = _grow_community(unit)
+            if comm:
+                yield comm
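
ECE growth is budgeted twice: a community stops absorbing neighbors once it holds max_units_per_community units or once the summed "length" fields reach max_tokens_per_community, and undersized results are discarded. A hedged usage sketch; the min_units_per_community keyword is an assumption inferred from the size check above, and g is any BaseGraphStorage such as the NetworkXStorage shown earlier:

from graphgen.models.partitioner.ece_partitioner import ECEPartitioner

for community in ECEPartitioner().partition(
    g,
    max_units_per_community=10,
    min_units_per_community=3,      # assumed parameter name
    max_tokens_per_community=10240,
    unit_sampling="random",
):
    print(community.id, len(community.nodes) + len(community.edges))
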
graphgen/models/partitioner/leiden_partitioner.py
CHANGED

@@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner):
     Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
     """

-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_size: int = 20,
@@ -37,12 +37,10 @@ class LeidenPartitioner(BasePartitioner):
         nodes = g.get_all_nodes()  # List[Tuple[str, dict]]
         edges = g.get_all_edges()  # List[Tuple[str, str, dict]]

-        node2cid: Dict[str, int] = await self._run_leiden(
-            nodes, edges, use_lcc, random_seed
-        )
+        node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed)

         if max_size is not None and max_size > 0:
-            node2cid = await self._split_communities(node2cid, max_size)
+            node2cid = self._split_communities(node2cid, max_size)

         cid2nodes: Dict[int, List[str]] = defaultdict(list)
         for n, cid in node2cid.items():
@@ -58,7 +56,7 @@ class LeidenPartitioner(BasePartitioner):
         return communities

     @staticmethod
-    async def _run_leiden(
+    def _run_leiden(
         nodes: List[Tuple[str, dict]],
         edges: List[Tuple[str, str, dict]],
         use_lcc: bool = False,
@@ -92,9 +90,7 @@ class LeidenPartitioner(BasePartitioner):
         return node2cid

     @staticmethod
-    async def _split_communities(
-        node2cid: Dict[str, int], max_size: int
-    ) -> Dict[str, int]:
+    def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]:
         """
         Split communities larger than max_size into smaller sub-communities.
         """
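
The body of _split_communities is not shown in this view; below is a minimal, self-contained sketch of the idea its docstring names: chunk any community over max_size into consecutive sub-communities with fresh ids. Names and chunking order are assumptions, not the committed implementation.

from collections import defaultdict
from typing import Dict, List

def split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]:
    # Group node ids by their current community id.
    cid2nodes: Dict[int, List[str]] = defaultdict(list)
    for node, cid in node2cid.items():
        cid2nodes[cid].append(node)

    new_map: Dict[str, int] = {}
    next_cid = 0
    for _, members in sorted(cid2nodes.items()):
        # Chunk oversized communities into blocks of at most max_size nodes;
        # communities already within the limit get exactly one block.
        for start in range(0, len(members), max_size):
            for node in members[start : start + max_size]:
                new_map[node] = next_cid
            next_cid += 1
    return new_map
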
graphgen/models/reader/__init__.py
CHANGED

@@ -1,6 +1,5 @@
 from .csv_reader import CSVReader
 from .json_reader import JSONReader
-from .jsonl_reader import JSONLReader
 from .parquet_reader import ParquetReader
 from .pdf_reader import PDFReader
 from .pickle_reader import PickleReader
graphgen/models/reader/csv_reader.py
CHANGED

@@ -1,6 +1,7 @@
-from typing import Any, Dict, List
+from typing import List, Union

-import pandas as pd
+import ray
+from ray.data import Dataset

 from graphgen.bases.base_reader import BaseReader

@@ -13,13 +14,15 @@ class CSVReader(BaseReader):
     - if type is "text", "content" column must be present.
     """

-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        ...
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+        """
+        Read CSV files and return Ray Dataset.
+
+        :param input_path: Path to CSV file or list of CSV files.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+
+        ds = ray.data.read_csv(input_path)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
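
A quick end-to-end sketch of the CSV shape this reader expects; the file name is illustrative and the no-argument reader construction is an assumption about BaseReader defaults:

import pandas as pd
from graphgen.models.reader import CSVReader

# A "type" column is required; text rows must also carry "content".
pd.DataFrame(
    [
        {"type": "text", "content": "GraphGen builds QA data from graphs."},
        {"type": "text", "content": "Ray Data shards the reading work."},
    ]
).to_csv("docs.csv", index=False)

ds = CSVReader().read("docs.csv")
print(ds.take(2))  # rows that passed _validate_batch and _should_keep_item
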
graphgen/models/reader/json_reader.py
CHANGED

@@ -1,26 +1,53 @@
 import json
-from typing import Any, Dict, List
+from typing import List, Union
+
+import ray
+import ray.data

 from graphgen.bases.base_reader import BaseReader


 class JSONReader(BaseReader):
     """
-    Reader for JSON files.
+    Reader for JSON and JSONL files.
     Columns:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
     """

-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        ...
+    def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset:
+        """
+        Read JSON file and return Ray Dataset.
+        :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files.
+        :return: Ray Dataset containing validated and filtered data.
+        """
+        if self.modalities and len(self.modalities) >= 2:
+            ds: ray.data.Dataset = ray.data.from_items([])
+            for file in input_path if isinstance(input_path, list) else [input_path]:
+                data = []
+                if file.endswith(".jsonl"):
+                    with open(file, "r", encoding="utf-8") as f:
+                        for line in f:
+                            item = json.loads(line)
+                            data.append(item)
+                else:
+                    with open(file, "r", encoding="utf-8") as f:
+                        data = json.load(f)
+                data = self._unify_schema(data)
+                file_ds: ray.data.Dataset = ray.data.from_items(data)
+                ds = ds.union(file_ds)  # type: ignore
+        else:
+            ds = ray.data.read_json(input_path)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
+
+    @staticmethod
+    def _unify_schema(data):
+        """
+        Unify schema for JSON data.
+        """
+        for item in data:
+            if "content" in item and isinstance(item["content"], dict):
+                item["content"] = json.dumps(item["content"])
+        return data
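
The deleted JSONLReader below is subsumed by this reader: in the single-modality path a .jsonl file goes straight through ray.data.read_json. A hedged round-trip sketch; the file name is illustrative and the default reader construction is assumed:

import json
from graphgen.models.reader import JSONReader

rows = [
    {"type": "text", "content": "Kuzu is an embedded graph database."},
    {"type": "text", "content": "NetworkX keeps the graph in memory."},
]
with open("docs.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

ds = JSONReader().read("docs.jsonl")
print(ds.count())
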
graphgen/models/reader/jsonl_reader.py
DELETED

@@ -1,30 +0,0 @@
-import json
-from typing import Any, Dict, List
-
-from graphgen.bases.base_reader import BaseReader
-from graphgen.utils import logger
-
-
-class JSONLReader(BaseReader):
-    """
-    Reader for JSONL files.
-    Columns:
-    - type: The type of the document (e.g., "text", "image", etc.)
-    - if type is "text", "content" column must be present.
-    """
-
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        docs = []
-        with open(file_path, "r", encoding="utf-8") as f:
-            for line in f:
-                try:
-                    doc = json.loads(line)
-                    assert "type" in doc, f"Missing 'type' in document: {doc}"
-                    if doc.get("type") == "text" and self.text_column not in doc:
-                        raise ValueError(
-                            f"Missing '{self.text_column}' in document: {doc}"
-                        )
-                    docs.append(doc)
-                except json.JSONDecodeError as e:
-                    logger.error("Error decoding JSON line: %s. Error: %s", line, e)
-        return self.filter(docs)
graphgen/models/reader/parquet_reader.py
CHANGED

@@ -1,6 +1,7 @@
-from typing import Any, Dict, List
+from typing import List, Union

-import pandas as pd
+import ray
+from ray.data import Dataset

 from graphgen.bases.base_reader import BaseReader

@@ -13,12 +14,17 @@ class ParquetReader(BaseReader):
     - if type is "text", "content" column must be present.
     """

-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        ...
+    def read(self, input_path: Union[str, List[str]]) -> Dataset:
+        """
+        Read Parquet files using Ray Data.
+
+        :param input_path: Path to Parquet file or list of Parquet files.
+        :return: Ray Dataset containing validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        ds = ray.data.read_parquet(input_path)
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
+        ds = ds.filter(self._should_keep_item)
+        return ds
graphgen/models/reader/pdf_reader.py
CHANGED

@@ -5,6 +5,9 @@ import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union

+import ray
+from ray.data import Dataset
+
 from graphgen.bases.base_reader import BaseReader
 from graphgen.models.reader.txt_reader import TXTReader
 from graphgen.utils import logger, pick_device
@@ -62,19 +65,31 @@ class PDFReader(BaseReader):
         self.parser = MinerUParser()
         self.txt_reader = TXTReader()

-    def read(
-        self,
-        file_path: str,
-        **override,
-    ) -> List[Dict[str, Any]]:
-        ...
-        return self.filter(mineru_result)
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+        **override,
+    ) -> Dataset:
+
+        # Ensure input_path is a list
+        if isinstance(input_path, str):
+            input_path = [input_path]
+
+        paths_ds = ray.data.from_items(input_path)
+
+        def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
+            try:
+                pdf_path = row["item"]
+                kwargs = {**self._default_kwargs, **override}
+                return self._call_mineru(Path(pdf_path), kwargs)
+            except Exception as e:
+                logger.error("Failed to process %s: %s", row, e)
+                return []

+        docs_ds = paths_ds.flat_map(process_pdf)
+        docs_ds = docs_ds.filter(self._should_keep_item)
+
+        return docs_ds

     def _call_mineru(
         self, pdf_path: Path, kwargs: Dict[str, Any]
@@ -161,18 +176,18 @@ class MinerUParser:

         base = os.path.dirname(json_file)
         results = []
-        for item in data:
+        for it in data:
             for key in ("img_path", "table_img_path", "equation_img_path"):
-                rel_path = item.get(key)
+                rel_path = it.get(key)
                 if rel_path:
-                    item[key] = str(Path(base).joinpath(rel_path).resolve())
-            if item["type"] == "text":
-                item["content"] = item["text"]
-                del item["text"]
+                    it[key] = str(Path(base).joinpath(rel_path).resolve())
+            if it["type"] == "text":
+                it["content"] = it["text"]
+                del it["text"]
             for key in ("page_idx", "bbox", "text_level"):
-                if item.get(key) is not None:
-                    del item[key]
-            results.append(item)
+                if it.get(key) is not None:
+                    del it[key]
+            results.append(it)
         return results

     @staticmethod
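
A hedged usage sketch; the PDF paths are illustrative, the default PDFReader construction is assumed, and MinerU plus its model weights must be installed for the parse step to succeed:

from graphgen.models.reader import PDFReader

reader = PDFReader()
ds = reader.read(["paper1.pdf", "paper2.pdf"])
# Each PDF becomes typed blocks ("text", "image", ...); text blocks carry "content".
print(ds.take(3))
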
graphgen/models/reader/pickle_reader.py
CHANGED

@@ -1,30 +1,78 @@
 import pickle
-from typing import Any, Dict, List
+from typing import List, Union
+
+import pandas as pd
+import ray
+from ray.data import Dataset

 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger


 class PickleReader(BaseReader):
     """
-    Read pickle files, requiring the ...
-
-    Columns:
+    Read pickle files, requiring the schema to be restored to List[Dict[str, Any]].
+    Each pickle file should contain a list of dictionaries with at least:
     - type: The type of the document (e.g., "text", "image", etc.)
     - if type is "text", "content" column must be present.
+
+    Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available.
+    For Ray >= 2.5, consider using read_pickle if available in your version.
     """

-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        ...
-        raise ValueError("Every item in the list must be a dict.")
-        assert "type" in doc, f"Missing 'type' in document: {doc}"
-        if doc.get("type") == "text" and self.text_column not in doc:
-            raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
-        ...
-        return self.filter(docs)
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+    ) -> Dataset:
+        """
+        Read Pickle files using Ray Data.
+
+        :param input_path: Path to pickle file or list of pickle files.
+        :return: Ray Dataset containing validated documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Use read_binary_files as a reliable alternative to read_pickle
+        ds = ray.data.read_binary_files(input_path, include_paths=True)
+
+        # Deserialize pickle files and flatten into individual records
+        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+            all_records = []
+            for _, row in batch.iterrows():
+                try:
+                    # Load pickle data from bytes
+                    data = pickle.loads(row["bytes"])
+
+                    # Validate structure
+                    if not isinstance(data, list):
+                        logger.error(
+                            "Pickle file %s must contain a list, got %s",
+                            row["path"],
+                            type(data),
+                        )
+                        continue
+
+                    if not all(isinstance(item, dict) for item in data):
+                        logger.error(
+                            "Pickle file %s must contain a list of dictionaries",
+                            row["path"],
+                        )
+                        continue
+
+                    # Flatten: each dict in the list becomes a separate row
+                    all_records.extend(data)
+                except Exception as e:
+                    logger.error(
+                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                    )
+                    continue

+            return pd.DataFrame(all_records)

+        # Apply deserialization and flattening
+        ds = ds.map_batches(deserialize_batch, batch_format="pandas")

+        # Validate the schema
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")

+        # Filter valid items
+        ds = ds.filter(self._should_keep_item)
+        return ds
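
A round-trip sketch showing the pickle layout this reader expects; the file name is illustrative and the default reader construction is assumed:

import pickle
from graphgen.models.reader import PickleReader

docs = [
    {"type": "text", "content": "alpha"},
    {"type": "text", "content": "beta"},
]
with open("docs.pkl", "wb") as f:
    pickle.dump(docs, f)  # must be a list of dicts, each with a "type" key

ds = PickleReader().read("docs.pkl")
print(ds.take_all())
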
graphgen/models/reader/rdf_reader.py
CHANGED

@@ -1,48 +1,128 @@
-from typing import Any, Dict, List
+from pathlib import Path
+from typing import Any, Dict, List, Union

+import ray
 import rdflib
+from ray.data import Dataset
 from rdflib import Literal
 from rdflib.util import guess_format

 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger


 class RDFReader(BaseReader):
     """
     Reader for RDF files that extracts triples and represents them as dictionaries.
+
+    Uses Ray Data for distributed processing of multiple RDF files.
     """

-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        g = rdflib.Graph()
-        fmt = guess_format(file_path)
-        try:
-            g.parse(file_path, format=fmt)
-        except Exception as e:
-            raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e
-
-        docs: List[Dict[str, Any]] = []
-        text_col = self.text_column
-
-        for subj in set(g.subjects()):
-            literals = []
-            props = {}
-            for _, pred, obj in g.triples((subj, None, None)):
-                pred_str = str(pred)
-                if isinstance(obj, Literal):
-                    literals.append(str(obj))
-                ...
-
-            text = " ".join(literals).strip()
-            if not text:
-                ...
-
-            ...
-            docs.append(doc)
-
-        if not docs:
-            ...
-
-        return self.filter(docs)
+    def __init__(self, *, text_column: str = "content", **kwargs):
+        """
+        Initialize RDFReader.
+
+        :param text_column: The column name for text content (default: "content").
+        """
+        super().__init__(**kwargs)
+        self.text_column = text_column
+
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+    ) -> Dataset:
+        """
+        Read RDF file(s) using Ray Data.
+
+        :param input_path: Path to RDF file or list of RDF files.
+        :return: Ray Dataset containing extracted documents.
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Ensure input_path is a list to prevent Ray from splitting string into characters
+        if isinstance(input_path, str):
+            input_path = [input_path]
+
+        # Create dataset from file paths
+        paths_ds = ray.data.from_items(input_path)
+
+        def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]:
+            """Process a single RDF file and return list of documents."""
+            try:
+                file_path = row["item"]
+                return self._parse_rdf_file(Path(file_path))
+            except Exception as e:
+                logger.error(
+                    "Failed to process RDF file %s: %s", row.get("item", "unknown"), e
+                )
+                return []
+
+        # Process files in parallel and flatten results
+        docs_ds = paths_ds.flat_map(process_rdf)
+
+        # Filter valid documents
+        docs_ds = docs_ds.filter(self._should_keep_item)
+
+        return docs_ds
+
+    def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]:
+        """
+        Parse a single RDF file and extract documents.
+
+        :param file_path: Path to RDF file.
+        :return: List of document dictionaries.
+        """
+        if not file_path.is_file():
+            raise FileNotFoundError(f"RDF file not found: {file_path}")
+
+        g = rdflib.Graph()
+        fmt = guess_format(str(file_path))
+
+        try:
+            g.parse(str(file_path), format=fmt)
+        except Exception as e:
+            raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e
+
+        docs: List[Dict[str, Any]] = []
+
+        # Process each unique subject in the RDF graph
+        for subj in set(g.subjects()):
+            literals = []
+            props = {}
+
+            # Extract all triples for this subject
+            for _, pred, obj in g.triples((subj, None, None)):
+                pred_str = str(pred)
+                obj_str = str(obj)
+
+                # Collect literal values as text content
+                if isinstance(obj, Literal):
+                    literals.append(obj_str)
+
+                # Store all properties (including non-literals)
+                props.setdefault(pred_str, []).append(obj_str)
+
+            # Join all literal values as the text content
+            text = " ".join(literals).strip()
+            if not text:
+                logger.warning(
+                    "Subject %s in %s has no literal values; document will have empty '%s' field.",
+                    subj,
+                    file_path,
+                    self.text_column,
+                )
+
+            # Create document dictionary
+            doc = {
+                "id": str(subj),
+                self.text_column: text,
+                "properties": props,
+                "source_file": str(file_path),
+            }
+            docs.append(doc)
+
+        if not docs:
+            logger.warning("RDF file %s contains no valid documents.", file_path)
+
+        return docs
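
A small Turtle file makes the subject-to-document mapping concrete; the prefix and triples are illustrative:

from graphgen.models.reader.rdf_reader import RDFReader

turtle = """
@prefix ex: <http://example.org/> .
ex:alice ex:name "Alice" ;
         ex:knows ex:bob .
"""
with open("sample.ttl", "w", encoding="utf-8") as f:
    f.write(turtle)

ds = RDFReader().read("sample.ttl")
# One document per subject: here only ex:alice, whose literal objects
# ("Alice") become its content; ex:bob appears only as an object.
print(ds.take_all())
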
graphgen/models/reader/txt_reader.py
CHANGED

@@ -1,10 +1,32 @@
-from typing import Any, Dict, List
+from typing import List, Union
+
+import ray
+from ray.data import Dataset

 from graphgen.bases.base_reader import BaseReader


 class TXTReader(BaseReader):
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
-        ...
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+    ) -> Dataset:
+        """
+        Read text files from the specified input path.
+        :param input_path: Path to the input text file or list of text files.
+        :return: Ray Dataset containing the read text data.
+        """
+        docs_ds = ray.data.read_binary_files(
+            input_path,
+            include_paths=False,
+        )
+
+        docs_ds = docs_ds.map(
+            lambda row: {
+                "type": "text",
+                self.text_column: row["bytes"].decode("utf-8"),
+            }
+        )
+
+        docs_ds = docs_ds.filter(self._should_keep_item)
+        return docs_ds
graphgen/models/splitter/character_splitter.py
CHANGED

@@ -17,7 +17,7 @@ class CharacterSplitter(BaseSplitter):

     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
-        # First we naively split the large input into a bunch of smaller ones.
+        # First we naively chunk the large input into a bunch of smaller ones.
         separator = (
             self._separator if self._is_separator_regex else re.escape(self._separator)
         )
graphgen/models/splitter/markdown_splitter.py
CHANGED

@@ -6,12 +6,12 @@ from graphgen.models.splitter.recursive_character_splitter import (


 class MarkdownTextRefSplitter(RecursiveCharacterSplitter):
-    """Attempts to split the text along Markdown-formatted headings."""
+    """Attempts to chunk the text along Markdown-formatted headings."""

     def __init__(self, **kwargs: Any) -> None:
         """Initialize a MarkdownTextRefSplitter."""
         separators = [
-            # First, try to split along Markdown headings (starting with level 2)
+            # First, try to chunk along Markdown headings (starting with level 2)
             "\n#{1,6} ",
             # Note the alternative syntax for headings (below) is not handled here
             # Heading level 2
graphgen/models/splitter/recursive_character_splitter.py
CHANGED

@@ -7,7 +7,7 @@ from graphgen.bases.base_splitter import BaseSplitter
 class RecursiveCharacterSplitter(BaseSplitter):
     """Splitting text by recursively look at characters.

-    Recursively tries to split by different characters to find one that works.
+    Recursively tries to chunk by different characters to find one that works.
     """

     def __init__(
@@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter):
     def _split_text_with_regex_from_end(
         self, text: str, separator: str, keep_separator: bool
     ) -> List[str]:
-        # Now that we have the separator, split the text
+        # Now that we have the separator, chunk the text
         if separator:
             if keep_separator:
                 # The parentheses in the pattern keep the delimiters in the result.
graphgen/models/storage/__init__.py
CHANGED

@@ -1,3 +1,6 @@
-from .json_storage import JsonKVStorage
-from .networkx_storage import NetworkXStorage
+from graphgen.models.storage.graph.kuzu_storage import KuzuStorage
+from graphgen.models.storage.graph.networkx_storage import NetworkXStorage
+from graphgen.models.storage.kv.json_storage import JsonKVStorage
+from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage
+
 from .rocksdb_cache import RocksDBCache
graphgen/{configs → models/storage/graph}/__init__.py
RENAMED
File without changes
graphgen/models/storage/graph/kuzu_storage.py
ADDED

@@ -0,0 +1,256 @@
+import json
+import os
+import shutil
+from dataclasses import dataclass
+from typing import Any
+
+try:
+    import kuzu
+except ImportError:
+    kuzu = None
+
+from graphgen.bases.base_storage import BaseGraphStorage
+
+
+@dataclass
+class KuzuStorage(BaseGraphStorage):
+    """
+    Graph storage implementation based on KuzuDB.
+    Since KuzuDB is a structured graph database and GraphGen uses dynamic dictionaries for properties,
+    we map the data to a generic schema:
+    - Node Table 'Entity': {id: STRING, data: STRING (JSON)}
+    - Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)}
+    """
+
+    working_dir: str = None
+    namespace: str = None
+    _db: Any = None
+    _conn: Any = None
+
+    def __post_init__(self):
+        if kuzu is None:
+            raise ImportError(
+                "KuzuDB is not installed. Please install it via `pip install kuzu`."
+            )
+
+        self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu")
+        self._init_db()
+
+    def _init_db(self):
+        # KuzuDB automatically creates the directory
+        self._db = kuzu.Database(self.db_path)
+        self._conn = kuzu.Connection(self._db)
+        self._init_schema()
+        print(f"KuzuDB initialized at {self.db_path}")
+
+    def _init_schema(self):
+        """Initialize the generic Node and Edge tables if they don't exist."""
+        # Check and create Node table
+        try:
+            # We use a generic table name "Entity" to store all nodes
+            self._conn.execute(
+                "CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))"
+            )
+            print("Created KuzuDB Node Table 'Entity'")
+        except RuntimeError as e:
+            # Usually throws if table exists, verify safely or ignore
+            print("Node Table 'Entity' already exists or error:", e)
+
+        # Check and create Edge table
+        try:
+            # We use a generic table name "Relation" to store all edges
+            self._conn.execute(
+                "CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)"
+            )
+            print("Created KuzuDB Rel Table 'Relation'")
+        except RuntimeError as e:
+            print("Rel Table 'Relation' already exists or error:", e)
+
+    def index_done_callback(self):
+        """KuzuDB is ACID, changes are immediate, but we can verify generic persistence here."""
+
+    def has_node(self, node_id: str) -> bool:
+        result = self._conn.execute(
+            "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id}
+        )
+        count = result.get_next()[0]
+        return count > 0
+
+    def has_edge(self, source_node_id: str, target_node_id: str):
+        result = self._conn.execute(
+            "MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)",
+            {"src": source_node_id, "dst": target_node_id},
+        )
+        count = result.get_next()[0]
+        return count > 0
+
+    def node_degree(self, node_id: str) -> int:
+        # Calculate total degree (incoming + outgoing)
+        query = """
+            MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity)
+            RETURN count(e)
+        """
+        result = self._conn.execute(query, {"id": node_id})
+        if result.has_next():
+            return result.get_next()[0]
+        return 0
+
+    def edge_degree(self, src_id: str, tgt_id: str) -> int:
+        # In this context, usually checks existence or multiplicity.
+        # Kuzu supports multi-edges, so we count them.
+        query = """
+            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+            RETURN count(e)
+        """
+        result = self._conn.execute(query, {"src": src_id, "dst": tgt_id})
+        if result.has_next():
+            return result.get_next()[0]
+        return 0
+
+    def get_node(self, node_id: str) -> Any:
+        result = self._conn.execute(
+            "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id}
+        )
+        if result.has_next():
+            data_str = result.get_next()[0]
+            return json.loads(data_str) if data_str else {}
+        return None
+
+    def update_node(self, node_id: str, node_data: dict[str, str]):
+        current_data = self.get_node(node_id)
+        if current_data is None:
+            print(f"Node {node_id} not found for update.")
+            return
+
+        # Merge existing data with new data
+        current_data.update(node_data)
+        json_data = json.dumps(current_data, ensure_ascii=False)
+
+        self._conn.execute(
+            "MATCH (a:Entity {id: $id}) SET a.data = $data",
+            {"id": node_id, "data": json_data},
+        )
+
+    def get_all_nodes(self) -> Any:
+        """Returns List[Tuple[id, data_dict]]"""
+        result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data")
+        nodes = []
+        while result.has_next():
+            row = result.get_next()
+            nodes.append((row[0], json.loads(row[1])))
+        return nodes
+
+    def get_edge(self, source_node_id: str, target_node_id: str):
+        # Warning: If multiple edges exist, this returns the first one found
+        query = """
+            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+            RETURN e.data
+        """
+        result = self._conn.execute(
+            query, {"src": source_node_id, "dst": target_node_id}
+        )
+        if result.has_next():
+            data_str = result.get_next()[0]
+            return json.loads(data_str) if data_str else {}
+        return None
+
+    def update_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        current_data = self.get_edge(source_node_id, target_node_id)
+        if current_data is None:
+            print(f"Edge {source_node_id}->{target_node_id} not found for update.")
+            return
+
+        current_data.update(edge_data)
+        json_data = json.dumps(current_data, ensure_ascii=False)
+
+        query = """
+            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst})
+            SET e.data = $data
+        """
+        self._conn.execute(
+            query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
+        )
+
+    def get_all_edges(self) -> Any:
+        """Returns List[Tuple[src, dst, data_dict]]"""
+        query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data"
+        result = self._conn.execute(query)
+        edges = []
+        while result.has_next():
+            row = result.get_next()
+            edges.append((row[0], row[1], json.loads(row[2])))
+        return edges
+
+    def get_node_edges(self, source_node_id: str) -> Any:
+        """Returns generic edges connected to this node (outgoing)"""
+        query = """
+            MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity)
+            RETURN a.id, b.id, e.data
+        """
+        result = self._conn.execute(query, {"src": source_node_id})
+        edges = []
+        while result.has_next():
+            row = result.get_next()
+            edges.append((row[0], row[1], json.loads(row[2])))
+        return edges
+
+    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+        """
+        Insert or Update node.
+        Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
+        """
+        json_data = json.dumps(node_data, ensure_ascii=False)
+        query = """
+            MERGE (a:Entity {id: $id})
+            ON MATCH SET a.data = $data
+            ON CREATE SET a.data = $data
+        """
+        self._conn.execute(query, {"id": node_id, "data": json_data})
+
+    def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+    ):
+        """
+        Insert or Update edge.
+        Note: We explicitly ensure nodes exist before merging the edge to avoid errors,
+        although GraphGen generally creates nodes before edges.
+        """
+        # Ensure source node exists
+        if not self.has_node(source_node_id):
+            self.upsert_node(source_node_id, {})
+        # Ensure target node exists
+        if not self.has_node(target_node_id):
+            self.upsert_node(target_node_id, {})
+
+        json_data = json.dumps(edge_data, ensure_ascii=False)
+        query = """
+            MATCH (a:Entity {id: $src}), (b:Entity {id: $dst})
+            MERGE (a)-[e:Relation]->(b)
+            ON MATCH SET e.data = $data
+            ON CREATE SET e.data = $data
+        """
+        self._conn.execute(
+            query, {"src": source_node_id, "dst": target_node_id, "data": json_data}
+        )
+
+    def delete_node(self, node_id: str):
+        # DETACH DELETE removes the node and all connected edges
+        query = "MATCH (a:Entity {id: $id}) DETACH DELETE a"
+        self._conn.execute(query, {"id": node_id})
+        print(f"Node {node_id} deleted from KuzuDB.")
+
+    def clear(self):
+        """Clear all data but keep schema (or drop tables)."""
+        self._conn.execute("MATCH (n) DETACH DELETE n")
+        print(f"Graph {self.namespace} cleared.")
+
+    def reload(self):
+        """For databases that need reloading, KuzuDB auto-manages this."""
+
+    def drop(self):
+        """Completely remove the database folder."""
+        if self.db_path and os.path.exists(self.db_path):
+            shutil.rmtree(self.db_path)
+        print(f"Dropped KuzuDB at {self.db_path}")
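
A hedged usage sketch for the new backend; the working_dir and namespace values are illustrative, and `pip install kuzu` is required:

from graphgen.models.storage import KuzuStorage

store = KuzuStorage(working_dir="cache", namespace="demo")
store.upsert_node("graphgen", {"entity_type": "project"})
store.upsert_node("kuzu", {"entity_type": "database"})
store.upsert_edge("graphgen", "kuzu", {"relation": "uses"})

print(store.get_node("graphgen"))        # {'entity_type': 'project'}
print(store.get_node_edges("graphgen"))  # [('graphgen', 'kuzu', {'relation': 'uses'})]

store.drop()  # removes the on-disk database folder
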
graphgen/models/storage/{networkx_storage.py → graph/networkx_storage.py}
RENAMED

@@ -6,7 +6,6 @@ from typing import Any, Optional, Union, cast
 import networkx as nx

 from graphgen.bases.base_storage import BaseGraphStorage
-from graphgen.utils import logger


 @dataclass
@@ -19,11 +18,6 @@ class NetworkXStorage(BaseGraphStorage):

     @staticmethod
     def write_nx_graph(graph: nx.Graph, file_name):
-        logger.info(
-            "Writing graph with %d nodes, %d edges",
-            graph.number_of_nodes(),
-            graph.number_of_edges(),
-        )
         nx.write_graphml(graph, file_name)

     @staticmethod
@@ -82,12 +76,11 @@ class NetworkXStorage(BaseGraphStorage):
             self.working_dir, f"{self.namespace}.graphml"
         )
         preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
-        if preloaded_graph is not None:
-            logger.info(
-                "Loaded graph from %s with %d nodes, %d edges",
-                self._graphml_xml_file,
-                preloaded_graph.number_of_nodes(),
-                preloaded_graph.number_of_edges(),
+        if preloaded_graph:
+            print(
+                f"Loaded graph from {self._graphml_xml_file} with "
+                f"{preloaded_graph.number_of_nodes()} nodes, "
+                f"{preloaded_graph.number_of_edges()} edges"
             )
         self._graph = preloaded_graph or nx.Graph()

@@ -133,7 +126,7 @@ class NetworkXStorage(BaseGraphStorage):
         if self._graph.has_node(node_id):
             self._graph.nodes[node_id].update(node_data)
         else:
-            logger.warning("Node %s not found in the graph for update.", node_id)
+            print(f"Node {node_id} not found in the graph for update.")

     def upsert_edge(
         self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
@@ -146,10 +139,8 @@ class NetworkXStorage(BaseGraphStorage):
         if self._graph.has_edge(source_node_id, target_node_id):
             self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
         else:
-            logger.warning(
-                "Edge %s -> %s not found in the graph for update.",
-                source_node_id,
-                target_node_id,
+            print(
+                f"Edge {source_node_id} -> {target_node_id} not found in the graph for update."
             )

     def delete_node(self, node_id: str):
@@ -160,13 +151,19 @@ class NetworkXStorage(BaseGraphStorage):
         """
         if self._graph.has_node(node_id):
             self._graph.remove_node(node_id)
+            print(f"Node {node_id} deleted from the graph.")
         else:
-            logger.warning("Node %s not found in the graph for deletion.", node_id)
+            print(f"Node {node_id} not found in the graph for deletion.")

     def clear(self):
         """
         Clear the graph by removing all nodes and edges.
         """
         self._graph.clear()
-        logger.info("Graph %s cleared.", self.namespace)
+        print(f"Graph {self.namespace} cleared.")
+
+    def reload(self):
+        """
+        Reload the graph from the GraphML file.
+        """
+        self.__post_init__()
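
The new reload() simply re-runs __post_init__, re-reading <namespace>.graphml from disk, so unsaved in-memory changes are discarded. A short sketch with illustrative paths:

from graphgen.models.storage import NetworkXStorage

store = NetworkXStorage(working_dir="cache", namespace="demo")
store.upsert_node("a", {"entity_type": "example"})
store.reload()  # back to whatever demo.graphml last persisted
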
graphgen/models/storage/kv/__init__.py
ADDED
File without changes