diff --git a/app.py b/app.py index dfd0edda5b989cea18c100802e8e7a4e0b70df8e..98b02601d90512c6a44703f8ff8e70d1963c4c1e 100644 --- a/app.py +++ b/app.py @@ -5,14 +5,12 @@ import tempfile from importlib.resources import files import gradio as gr -import pandas as pd +import ray from dotenv import load_dotenv -from graphgen.engine import Context, Engine, collect_ops -from graphgen.graphgen import GraphGen -from graphgen.models import OpenAIClient, Tokenizer -from graphgen.models.llm.limitter import RPM, TPM -from graphgen.utils import set_logger +from graphgen.engine import Engine +from graphgen.operators import operators +from graphgen.utils import CURRENT_LOGGER_VAR, set_logger from webui.base import WebuiParams from webui.i18n import Translate from webui.i18n import gettext as _ @@ -22,7 +20,6 @@ from webui.utils import cleanup_workspace, count_tokens, preview_file, setup_wor root_dir = files("webui").parent sys.path.append(root_dir) - load_dotenv() css = """ @@ -34,131 +31,136 @@ css = """ """ -def init_graph_gen(config: dict, env: dict) -> GraphGen: - # Set up working directory - log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) - set_logger(log_file, if_stream=True) - os.environ.update({k: str(v) for k, v in env.items()}) - - tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base")) - synthesizer_llm_client = OpenAIClient( - model=env.get("SYNTHESIZER_MODEL", ""), - base_url=env.get("SYNTHESIZER_BASE_URL", ""), - api_key=env.get("SYNTHESIZER_API_KEY", ""), - request_limit=True, - rpm=RPM(env.get("RPM", 1000)), - tpm=TPM(env.get("TPM", 50000)), - tokenizer=tokenizer_instance, - ) - trainee_llm_client = OpenAIClient( - model=env.get("TRAINEE_MODEL", ""), - base_url=env.get("TRAINEE_BASE_URL", ""), - api_key=env.get("TRAINEE_API_KEY", ""), - request_limit=True, - rpm=RPM(env.get("RPM", 1000)), - tpm=TPM(env.get("TPM", 50000)), - tokenizer=tokenizer_instance, - ) - - graph_gen = GraphGen( - working_dir=working_dir, - tokenizer_instance=tokenizer_instance, - synthesizer_llm_client=synthesizer_llm_client, - trainee_llm_client=trainee_llm_client, - ) - - return graph_gen - - -# pylint: disable=too-many-statements -def run_graphgen(params: WebuiParams, progress=gr.Progress()): - def sum_tokens(client): - return sum(u["total_tokens"] for u in client.token_usage) - +def _get_partition_params(params: WebuiParams): method = params.partition_method if method == "dfs": - partition_params = { + return { "max_units_per_community": params.dfs_max_units, } - elif method == "bfs": - partition_params = { + if method == "bfs": + return { "max_units_per_community": params.bfs_max_units, } - elif method == "leiden": - partition_params = { + if method == "leiden": + return { "max_size": params.leiden_max_size, "use_lcc": params.leiden_use_lcc, "random_seed": params.leiden_random_seed, } - else: # ece - partition_params = { - "max_units_per_community": params.ece_max_units, - "min_units_per_community": params.ece_min_units, - "max_tokens_per_community": params.ece_max_tokens, - "unit_sampling": params.ece_unit_sampling, - } + # ece + return { + "max_units_per_community": params.ece_max_units, + "min_units_per_community": params.ece_min_units, + "max_tokens_per_community": params.ece_max_tokens, + "unit_sampling": params.ece_unit_sampling, + } + + +# pylint: disable=too-many-statements +def run_graphgen(params: WebuiParams, progress=gr.Progress()): + # 1. 
Setup Workspace
+    log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
+    driver_logger = set_logger(log_file, "GraphGen", if_stream=True)
+    CURRENT_LOGGER_VAR.set(driver_logger)
+
+    # 2. Setup Environment Variables for Ray Actors/LLM Init
+    # The refactored code relies on env vars in graphgen/common/init_llm.py
+    os.environ["SYNTHESIZER_BACKEND"] = "openai_api"  # Assuming OpenAI compatible API
+    os.environ["SYNTHESIZER_BASE_URL"] = params.synthesizer_url
+    os.environ["SYNTHESIZER_API_KEY"] = params.api_key
+    os.environ["SYNTHESIZER_MODEL"] = params.synthesizer_model
+    os.environ["RPM"] = str(params.rpm)
+    os.environ["TPM"] = str(params.tpm)
+    os.environ["TOKENIZER_MODEL"] = params.tokenizer
+
+    if params.if_trainee_model:
+        os.environ["TRAINEE_BACKEND"] = "openai_api"
+        os.environ["TRAINEE_BASE_URL"] = params.trainee_url
+        os.environ["TRAINEE_API_KEY"] = params.trainee_api_key
+        os.environ["TRAINEE_MODEL"] = params.trainee_model
-    pipeline = [
+    # 3. Construct Pipeline Configuration (DAG)
+    nodes = [
         {
-            "name": "read",
-            "op_key": "read",
+            "id": "read",
+            "op_name": "read",
+            "type": "source",
+            "dependencies": [],
             "params": {
-                "input_file": params.upload_file,
+                "input_path": [params.upload_file],
             },
         },
         {
-            "name": "chunk",
-            "deps": ["read"],
-            "op_key": "chunk",
+            "id": "chunk",
+            "op_name": "chunk",
+            "type": "map_batch",
+            "dependencies": ["read"],
+            "execution_params": {"replicas": 1},
             "params": {
                 "chunk_size": params.chunk_size,
                 "chunk_overlap": params.chunk_overlap,
             },
         },
         {
-            "name": "build_kg",
-            "deps": ["chunk"],
-            "op_key": "build_kg",
+            "id": "build_kg",
+            "op_name": "build_kg",
+            "type": "map_batch",
+            "dependencies": ["chunk"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
         },
     ]
+    last_node_id = "build_kg"
+
+    # Optional: Quiz and Judge
     if params.if_trainee_model:
-        pipeline.append(
-            {
-                "name": "quiz_and_judge",
-                "deps": ["build_kg"],
-                "op_key": "quiz_and_judge",
-                "params": {"quiz_samples": params.quiz_samples, "re_judge": True},
-            }
-        )
-        pipeline.append(
+        nodes.append(
             {
-                "name": "partition",
-                "deps": ["quiz_and_judge"],
-                "op_key": "partition",
+                "id": "quiz",
+                "op_name": "quiz",
+                "type": "aggregate",  # QuizService uses aggregate in config
+                "dependencies": ["build_kg"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
                 "params": {
-                    "method": params.partition_method,
-                    "method_params": partition_params,
+                    "quiz_samples": params.quiz_samples,
+                    "concurrency_limit": 200,
                 },
             }
         )
-    else:
-        pipeline.append(
+
+        nodes.append(
             {
-                "name": "partition",
-                "deps": ["build_kg"],
-                "op_key": "partition",
-                "params": {
-                    "method": params.partition_method,
-                    "method_params": partition_params,
-                },
+                "id": "judge",
+                "op_name": "judge",
+                "type": "map_batch",
+                "dependencies": ["quiz"],
+                "execution_params": {"replicas": 1, "batch_size": 128},
             }
         )
-    pipeline.append(
+        last_node_id = "judge"
+
+    # Node: Partition
+    nodes.append(
+        {
+            "id": "partition",
+            "op_name": "partition",
+            "type": "aggregate",  # PartitionService uses aggregate
+            "dependencies": [last_node_id],
+            "params": {
+                "method": params.partition_method,
+                "method_params": _get_partition_params(params),
+            },
+        }
+    )
+
+    # Node: Generate
+    nodes.append(
         {
-            "name": "generate",
-            "deps": ["partition"],
-            "op_key": "generate",
+            "id": "generate",
+            "op_name": "generate",
+            "type": "map_batch",
+            "dependencies": ["partition"],
+            "execution_params": {"replicas": 1, "batch_size": 128},
             "params": {
                 "method": params.mode,
                 "data_format": params.data_format,
@@ -166,88 +168,50 @@
def run_graphgen(params: WebuiParams, progress=gr.Progress()): } ) - config = { - "if_trainee_model": params.if_trainee_model, - "read": {"input_file": params.upload_file}, - "pipeline": pipeline, - } + config = {"global_params": {"working_dir": working_dir}, "nodes": nodes} - env = { - "TOKENIZER_MODEL": params.tokenizer, - "SYNTHESIZER_BASE_URL": params.synthesizer_url, - "SYNTHESIZER_MODEL": params.synthesizer_model, - "TRAINEE_BASE_URL": params.trainee_url, - "TRAINEE_MODEL": params.trainee_model, - "SYNTHESIZER_API_KEY": params.api_key, - "TRAINEE_API_KEY": params.trainee_api_key, - "RPM": params.rpm, - "TPM": params.tpm, - } + try: + # 4. Initialize and Run Engine + # Initialize Ray if not already running (Engine handles this mostly, but good for safety) + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, log_to_driver=True) - # Test API connection - test_api_connection( - env["SYNTHESIZER_BASE_URL"], - env["SYNTHESIZER_API_KEY"], - env["SYNTHESIZER_MODEL"], - ) - if config["if_trainee_model"]: - test_api_connection( - env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"] - ) + engine = Engine(config, operators) - # Initialize GraphGen - graph_gen = init_graph_gen(config, env) - graph_gen.clear() - graph_gen.progress_bar = progress + # Start with an empty dataset to kick off the pipeline + ds = ray.data.from_items([]) - try: - ctx = Context(config=config, graph_gen=graph_gen) - ops = collect_ops(config, graph_gen) - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) - - # Save output - output_data = graph_gen.qa_storage.data - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as tmpfile: - json.dump(output_data, tmpfile, ensure_ascii=False) - output_file = tmpfile.name - - synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client) - trainee_tokens = ( - sum_tokens(graph_gen.trainee_llm_client) - if config["if_trainee_model"] - else 0 - ) - total_tokens = synthesizer_tokens + trainee_tokens - - data_frame = params.token_counter - try: - _update_data = [ - [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)] - ] - new_df = pd.DataFrame(_update_data, columns=data_frame.columns) - data_frame = new_df - - except Exception as e: - raise gr.Error(f"DataFrame operation error: {str(e)}") - - return output_file, gr.DataFrame( - label="Token Stats", - headers=["Source Text Token Count", "Expected Token Usage", "Token Used"], - datatype="str", - interactive=False, - value=data_frame, - visible=True, - wrap=True, - ) + # Execute pipeline + results = engine.execute(ds) + + # 5. Process Output + # Extract the result from the 'generate' node + if "generate" in results: + result_ds = results["generate"] + + # Create a temporary file to save the output + with tempfile.NamedTemporaryFile( + mode="w", suffix=".jsonl", delete=False, encoding="utf-8" + ) as tmpfile: + # Iterate over rows and write to file + for row in result_ds.iter_rows(): + json.dump(row, tmpfile, ensure_ascii=False) + tmpfile.write("\n") + output_file = tmpfile.name + else: + raise gr.Error("Generation step failed to produce output.") + + # Note: Dynamic token counting from distributed actors is not directly available + # via client properties in the new architecture. We return the estimated stats from input. 
+ + return output_file, params.token_counter except Exception as e: # pylint: disable=broad-except raise gr.Error(f"Error occurred: {str(e)}") finally: # Clean up workspace - cleanup_workspace(graph_gen.working_dir) + cleanup_workspace(working_dir) # Optional: keep for debugging or enable with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: @@ -267,7 +231,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: ("简体中文", "zh"), ], value="en", - # label=_("Language"), render=False, container=False, elem_classes=["center-row"], @@ -295,7 +258,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: os.path.join(root_dir, "webui", "translation.json"), lang_btn, placeholder_langs=["en", "zh"], - persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0 + persistant=False, ): lang_btn.render() @@ -701,7 +664,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: outputs=[output, token_counter], ) - if __name__ == "__main__": demo.queue(api_open=False, default_concurrency_limit=2) demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False) diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 3d0bc8001f41a37238d6d9603c4686df4a879f28..41136974aa388a59acef93aa92852f7cc63f8ac9 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -2,15 +2,11 @@ from .base_extractor import BaseExtractor from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder from .base_llm_wrapper import BaseLLMWrapper +from .base_operator import BaseOperator from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_searcher import BaseSearcher from .base_splitter import BaseSplitter -from .base_storage import ( - BaseGraphStorage, - BaseKVStorage, - BaseListStorage, - StorageNameSpace, -) +from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer -from .datatypes import Chunk, QAPair, Token +from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_llm_wrapper.py b/graphgen/bases/base_llm_wrapper.py index fdeda602b9adbab918e0bb2c63f21d5f5600f54c..8b6dbec7fdf1958e05e90b990588d59f4a054c15 100644 --- a/graphgen/bases/base_llm_wrapper.py +++ b/graphgen/bases/base_llm_wrapper.py @@ -72,9 +72,3 @@ class BaseLLMWrapper(abc.ABC): filtered = filtered.strip() return filtered if filtered else text.strip() - - def shutdown(self) -> None: - """Shutdown the LLM engine if applicable.""" - - def restart(self) -> None: - """Reinitialize the LLM engine if applicable.""" diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..300d31786ba0531d6226f10c1a62b2be8fbef3cd --- /dev/null +++ b/graphgen/bases/base_operator.py @@ -0,0 +1,57 @@ +import inspect +import os +from abc import ABC, abstractmethod +from typing import Iterable, Union + +import pandas as pd +import ray + +from graphgen.utils import CURRENT_LOGGER_VAR, set_logger + + +class BaseOperator(ABC): + def __init__(self, working_dir: str = "cache", op_name: str = None): + log_dir = os.path.join(working_dir, "logs") + self.op_name = op_name or self.__class__.__name__ + + try: + ctx = ray.get_runtime_context() + worker_id = ctx.get_actor_id() or ctx.get_worker_id() + worker_id_short = worker_id[-6:] if worker_id else "driver" + except Exception as e: + 
print( + "Warning: Could not get Ray worker ID, defaulting to 'local'. Exception:", + e, + ) + worker_id_short = "local" + + # e.g. cache/logs/ChunkService_a1b2c3.log + log_file = os.path.join(log_dir, f"{self.op_name}_{worker_id_short}.log") + + self.logger = set_logger( + log_file=log_file, name=f"{self.op_name}.{worker_id_short}", force=True + ) + + self.logger.info( + "[%s] Operator initialized on Worker %s", self.op_name, worker_id_short + ) + + def __call__( + self, batch: pd.DataFrame + ) -> Union[pd.DataFrame, Iterable[pd.DataFrame]]: + logger_token = CURRENT_LOGGER_VAR.set(self.logger) + try: + result = self.process(batch) + if inspect.isgenerator(result): + yield from result + else: + yield result + finally: + CURRENT_LOGGER_VAR.reset(logger_token) + + @abstractmethod + def process(self, batch): + raise NotImplementedError("Subclasses must implement the process method.") + + def get_logger(self): + return self.logger diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py index d74ff563d4ff392fe868bc33059ee218035aca6b..d948e3a7820795da84bd992606c57d700e3a0b2a 100644 --- a/graphgen/bases/base_partitioner.py +++ b/graphgen/bases/base_partitioner.py @@ -7,7 +7,7 @@ from graphgen.bases.datatypes import Community class BasePartitioner(ABC): @abstractmethod - async def partition( + def partition( self, g: BaseGraphStorage, **kwargs: Any, @@ -20,39 +20,34 @@ class BasePartitioner(ABC): """ @staticmethod - async def community2batch( - communities: List[Community], g: BaseGraphStorage - ) -> list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] + def community2batch( + comm: Community, g: BaseGraphStorage + ) -> tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ]: """ Convert communities to batches of nodes and edges. - :param communities + :param comm: Community :param g: Graph storage instance :return: List of batches, each batch is a tuple of (nodes, edges) """ - batches = [] - for comm in communities: - nodes = comm.nodes - edges = comm.edges - nodes_data = [] - for node in nodes: - node_data = g.get_node(node) - if node_data: - nodes_data.append((node, node_data)) - edges_data = [] - for u, v in edges: - edge_data = g.get_edge(u, v) + nodes = comm.nodes + edges = comm.edges + nodes_data = [] + for node in nodes: + node_data = g.get_node(node) + if node_data: + nodes_data.append((node, node_data)) + edges_data = [] + for u, v in edges: + edge_data = g.get_edge(u, v) + if edge_data: + edges_data.append((u, v, edge_data)) + else: + edge_data = g.get_edge(v, u) if edge_data: - edges_data.append((u, v, edge_data)) - else: - edge_data = g.get_edge(v, u) - if edge_data: - edges_data.append((v, u, edge_data)) - batches.append((nodes_data, edges_data)) - return batches + edges_data.append((v, u, edge_data)) + return nodes_data, edges_data @staticmethod def _build_adjacency_list( diff --git a/graphgen/bases/base_reader.py b/graphgen/bases/base_reader.py index 8977846917be86136aab12e0a7f625e327446c3b..5d2af7355667f831efa97e53a62e2a7bc470c245 100644 --- a/graphgen/bases/base_reader.py +++ b/graphgen/bases/base_reader.py @@ -1,8 +1,10 @@ import os from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import pandas as pd import requests +from ray.data import Dataset class BaseReader(ABC): @@ -10,56 +12,70 @@ class BaseReader(ABC): Abstract base class for reading and processing data. 
""" - def __init__(self, text_column: str = "content"): + def __init__(self, text_column: str = "content", modalities: list = None): self.text_column = text_column + self.modalities = modalities if modalities is not None else ["text"] @abstractmethod - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: """ Read data from the specified file path. - :param file_path: Path to the input file. - :return: List of dictionaries containing the data. + :param input_path: Path to the input file or list of file paths. + :return: Ray Dataset containing the read data. """ - @staticmethod - def filter(data: List[dict]) -> List[dict]: + def _should_keep_item(self, item: Dict[str, Any]) -> bool: + """ + Determine whether to keep the given item based on the text column. + + :param item: Dictionary representing a data entry. + :return: True if the item should be kept, False otherwise. """ - Filter out entries with empty or missing text in the specified column. + item_type = item.get("type") + assert item_type in [ + "text", + "image", + "table", + "equation", + "protein", + ], f"Unsupported item type: {item_type}" + if item_type == "text": + content = item.get(self.text_column, "").strip() + return bool(content) + return True - :param data: List of dictionaries containing the data. - :return: Filtered list of dictionaries. + def _validate_batch(self, batch: pd.DataFrame) -> pd.DataFrame: + """ + Validate data format. """ + if "type" not in batch.columns: + raise ValueError(f"Missing 'type' column. Found: {list(batch.columns)}") - def _image_exists(path_or_url: str, timeout: int = 3) -> bool: - """ - Check if an image exists at the given local path or URL. - :param path_or_url: Local file path or remote URL of the image. - :param timeout: Timeout for remote URL requests in seconds. - :return: True if the image exists, False otherwise. - """ - if not path_or_url: - return False - if not path_or_url.startswith(("http://", "https://", "ftp://")): - path = path_or_url.replace("file://", "", 1) - path = os.path.abspath(path) - return os.path.isfile(path) - try: - resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) - return resp.status_code == 200 - except requests.RequestException: - return False + if "text" in batch["type"].values: + if self.text_column not in batch.columns: + raise ValueError( + f"Missing '{self.text_column}' column for text documents" + ) - filtered_data = [] - for item in data: - if item.get("type") == "text": - content = item.get("content", "").strip() - if content: - filtered_data.append(item) - elif item.get("type") in ("image", "table", "equation"): - img_path = item.get("img_path") - if _image_exists(img_path): - filtered_data.append(item) - else: - filtered_data.append(item) - return filtered_data + return batch + + @staticmethod + def _image_exists(path_or_url: str, timeout: int = 3) -> bool: + """ + Check if an image exists at the given local path or URL. + :param path_or_url: Local file path or remote URL of the image. + :param timeout: Timeout for remote URL requests in seconds. + :return: True if the image exists, False otherwise. 
+ """ + if not path_or_url: + return False + if not path_or_url.startswith(("http://", "https://", "ftp://")): + path = path_or_url.replace("file://", "", 1) + path = os.path.abspath(path) + return os.path.isfile(path) + try: + resp = requests.head(path_or_url, allow_redirects=True, timeout=timeout) + return resp.status_code == 200 + except requests.RequestException: + return False diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index b2d1ad3a9f2e7dac135926d1906d29e701235d28..f77be6e41e887c4def5e5f822e89051e60ee1bf4 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger +from graphgen.utils.log import logger class BaseSplitter(ABC): @@ -33,7 +33,7 @@ class BaseSplitter(ABC): """ Split the input text into smaller chunks. - :param text: The input text to be split. + :param text: The input text to be chunk. :return: A list of text chunks. """ @@ -111,7 +111,7 @@ class BaseSplitter(ABC): def _split_text_with_regex( text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index bfcd658c1f0f08a884c41237a84496b403d2f235..ff7d2d1a05d2d40db3c22d951a2b0daf86420d94 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -16,23 +16,6 @@ class StorageNameSpace: """commit the storage operations after querying""" -class BaseListStorage(Generic[T], StorageNameSpace): - def all_items(self) -> list[T]: - raise NotImplementedError - - def get_by_index(self, index: int) -> Union[T, None]: - raise NotImplementedError - - def append(self, data: T): - raise NotImplementedError - - def upsert(self, data: list[T]): - raise NotImplementedError - - def drop(self): - raise NotImplementedError - - class BaseKVStorage(Generic[T], StorageNameSpace): def all_keys(self) -> list[str]: raise NotImplementedError @@ -58,6 +41,9 @@ class BaseKVStorage(Generic[T], StorageNameSpace): def drop(self): raise NotImplementedError + def reload(self): + raise NotImplementedError + class BaseGraphStorage(StorageNameSpace): def has_node(self, node_id: str) -> bool: @@ -105,3 +91,6 @@ class BaseGraphStorage(StorageNameSpace): def delete_node(self, node_id: str): raise NotImplementedError + + def reload(self): + raise NotImplementedError diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index cb3be3455b677769fc0b008d7baad1553e061733..df719fdf9a3542ef7c4f79aeb9408f70b7ff3c9e 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -2,6 +2,8 @@ import math from dataclasses import dataclass, field from typing import List, Union +from pydantic import BaseModel, Field, field_validator + @dataclass class Chunk: @@ -48,3 +50,45 @@ class Community: nodes: List[str] = field(default_factory=list) edges: List[tuple] = field(default_factory=list) metadata: dict = field(default_factory=dict) + + +class Node(BaseModel): + id: str = Field(..., description="unique node id") + op_name: str = Field(..., description="operator name") + type: str = Field( + ..., description="task type, e.g., map, filter, flatmap, aggregate, 
map_batch"
+    )
+    params: dict = Field(default_factory=dict, description="operator parameters")
+    dependencies: List[str] = Field(
+        default_factory=list, description="list of dependent node ids"
+    )
+    execution_params: dict = Field(
+        default_factory=dict, description="execution parameters like replicas, batch_size"
+    )
+
+    @field_validator("type")
+    @classmethod
+    def validate_type(cls, v: str) -> str:
+        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
+        if v not in valid_types:
+            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
+        return v
+
+
+class Config(BaseModel):
+    global_params: dict = Field(
+        default_factory=dict, description="global context for the computation graph"
+    )
+
+    nodes: List[Node] = Field(
+        ..., min_length=1, description="list of nodes in the computation graph"
+    )
+
+    @field_validator("nodes")
+    @classmethod
+    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
+        ids = [node.id for node in v]
+        if len(ids) != len(set(ids)):
+            duplicates = {id_ for id_ in ids if ids.count(id_) > 1}
+            raise ValueError(f"Duplicate node ids found: {duplicates}")
+        return v
diff --git a/graphgen/common/__init__.py b/graphgen/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb99459f25d318436501fb8c015e93a610e37fa
--- /dev/null
+++ b/graphgen/common/__init__.py
@@ -0,0 +1,2 @@
+from .init_llm import init_llm
+from .init_storage import init_storage
diff --git a/graphgen/common/init_llm.py b/graphgen/common/init_llm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4f8cc732fda893a193a4db675e8a23f717d0ff
--- /dev/null
+++ b/graphgen/common/init_llm.py
@@ -0,0 +1,177 @@
+import os
+from typing import Any, Dict, Optional
+
+import ray
+
+from graphgen.bases import BaseLLMWrapper
+from graphgen.common.init_storage import get_actor_handle
+from graphgen.models import Tokenizer
+
+
+class LLMServiceActor:
+    """
+    A Ray actor class to wrap LLM wrapper instances for distributed usage.
+ """ + + def __init__(self, backend: str, config: Dict[str, Any]): + self.backend = backend + tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base") + tokenizer = Tokenizer(model_name=tokenizer_model) + config["tokenizer"] = tokenizer + + if backend == "http_api": + from graphgen.models.llm.api.http_client import HTTPClient + + self.llm_instance = HTTPClient(**config) + elif backend in ("openai_api", "azure_openai_api"): + from graphgen.models.llm.api.openai_client import OpenAIClient + + # pass in concrete backend to the OpenAIClient so that internally we can distinguish + # between OpenAI and Azure OpenAI + self.llm_instance = OpenAIClient(**config, backend=backend) + elif backend == "ollama_api": + from graphgen.models.llm.api.ollama_client import OllamaClient + + self.llm_instance = OllamaClient(**config) + elif backend == "huggingface": + from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper + + self.llm_instance = HuggingFaceWrapper(**config) + elif backend == "sglang": + from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper + + self.llm_instance = SGLangWrapper(**config) + + elif backend == "vllm": + from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper + + self.llm_instance = VLLMWrapper(**config) + else: + raise NotImplementedError(f"Backend {backend} is not implemented yet.") + + async def generate_answer( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> str: + return await self.llm_instance.generate_answer(text, history, **extra) + + async def generate_topk_per_token( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> list: + return await self.llm_instance.generate_topk_per_token(text, history, **extra) + + async def generate_inputs_prob( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> list: + return await self.llm_instance.generate_inputs_prob(text, history, **extra) + + def ready(self) -> bool: + """A simple method to check if the actor is ready.""" + return True + + +class LLMServiceProxy(BaseLLMWrapper): + """ + A proxy class to interact with the LLMServiceActor for distributed LLM operations. + """ + + def __init__(self, actor_name: str): + super().__init__() + self.actor_handle = get_actor_handle(actor_name) + self._create_local_tokenizer() + + async def generate_answer( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> str: + object_ref = self.actor_handle.generate_answer.remote(text, history, **extra) + return await object_ref + + async def generate_topk_per_token( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> list: + object_ref = self.actor_handle.generate_topk_per_token.remote( + text, history, **extra + ) + return await object_ref + + async def generate_inputs_prob( + self, text: str, history: Optional[list[str]] = None, **extra: Any + ) -> list: + object_ref = self.actor_handle.generate_inputs_prob.remote( + text, history, **extra + ) + return await object_ref + + def _create_local_tokenizer(self): + tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer = Tokenizer(model_name=tokenizer_model) + + +class LLMFactory: + """ + A factory class to create LLM wrapper instances based on the specified backend. 
+ Supported backends include: + - http_api: HTTPClient + - openai_api: OpenAIClient + - ollama_api: OllamaClient + - huggingface: HuggingFaceWrapper + - sglang: SGLangWrapper + """ + + @staticmethod + def create_llm( + model_type: str, backend: str, config: Dict[str, Any] + ) -> BaseLLMWrapper: + if not config: + raise ValueError( + f"No configuration provided for LLM {model_type} with backend {backend}." + ) + + actor_name = f"Actor_LLM_{model_type}" + try: + ray.get_actor(actor_name) + except ValueError: + print(f"Creating Ray actor for LLM {model_type} with backend {backend}.") + num_gpus = int(config.pop("num_gpus", 0)) + actor = ( + ray.remote(LLMServiceActor) + .options( + name=actor_name, + num_gpus=num_gpus, + lifetime="detached", + get_if_exists=True, + ) + .remote(backend, config) + ) + + # wait for actor to be ready + ray.get(actor.ready.remote()) + + return LLMServiceProxy(actor_name) + + +def _load_env_group(prefix: str) -> Dict[str, Any]: + """ + Collect environment variables with the given prefix into a dictionary, + stripping the prefix from the keys. + """ + return { + k[len(prefix) :].lower(): v + for k, v in os.environ.items() + if k.startswith(prefix) + } + + +def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: + if model_type == "synthesizer": + prefix = "SYNTHESIZER_" + elif model_type == "trainee": + prefix = "TRAINEE_" + else: + raise NotImplementedError(f"Model type {model_type} is not implemented yet.") + config = _load_env_group(prefix) + # if config is empty, return None + if not config: + return None + backend = config.pop("backend") + llm_wrapper = LLMFactory.create_llm(model_type, backend, config) + return llm_wrapper diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..b9358485b259076b652b0d635dfa1c6b51dd4806 --- /dev/null +++ b/graphgen/common/init_storage.py @@ -0,0 +1,262 @@ +from typing import Any, Dict, Union + +import ray + +from graphgen.bases.base_storage import BaseGraphStorage, BaseKVStorage + + +class KVStorageActor: + def __init__(self, backend: str, working_dir: str, namespace: str): + if backend == "json_kv": + from graphgen.models import JsonKVStorage + + self.kv = JsonKVStorage(working_dir, namespace) + elif backend == "rocksdb": + from graphgen.models import RocksDBKVStorage + + self.kv = RocksDBKVStorage(working_dir, namespace) + else: + raise ValueError(f"Unknown KV backend: {backend}") + + def data(self) -> Dict[str, Dict]: + return self.kv.data + + def all_keys(self) -> list[str]: + return self.kv.all_keys() + + def index_done_callback(self): + return self.kv.index_done_callback() + + def get_by_id(self, id: str) -> Dict: + return self.kv.get_by_id(id) + + def get_by_ids(self, ids: list[str], fields=None) -> list: + return self.kv.get_by_ids(ids, fields) + + def get_all(self) -> Dict[str, Dict]: + return self.kv.get_all() + + def filter_keys(self, data: list[str]) -> set[str]: + return self.kv.filter_keys(data) + + def upsert(self, data: dict) -> dict: + return self.kv.upsert(data) + + def drop(self): + return self.kv.drop() + + def reload(self): + return self.kv.reload() + + +class GraphStorageActor: + def __init__(self, backend: str, working_dir: str, namespace: str): + if backend == "networkx": + from graphgen.models import NetworkXStorage + + self.graph = NetworkXStorage(working_dir, namespace) + elif backend == "kuzu": + from graphgen.models import KuzuStorage + + self.graph = KuzuStorage(working_dir, namespace) + else: + raise 
ValueError(f"Unknown Graph backend: {backend}") + + def index_done_callback(self): + return self.graph.index_done_callback() + + def has_node(self, node_id: str) -> bool: + return self.graph.has_node(node_id) + + def has_edge(self, source_node_id: str, target_node_id: str): + return self.graph.has_edge(source_node_id, target_node_id) + + def node_degree(self, node_id: str) -> int: + return self.graph.node_degree(node_id) + + def edge_degree(self, src_id: str, tgt_id: str) -> int: + return self.graph.edge_degree(src_id, tgt_id) + + def get_node(self, node_id: str) -> Any: + return self.graph.get_node(node_id) + + def update_node(self, node_id: str, node_data: dict[str, str]): + return self.graph.update_node(node_id, node_data) + + def get_all_nodes(self) -> Any: + return self.graph.get_all_nodes() + + def get_edge(self, source_node_id: str, target_node_id: str): + return self.graph.get_edge(source_node_id, target_node_id) + + def update_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + return self.graph.update_edge(source_node_id, target_node_id, edge_data) + + def get_all_edges(self) -> Any: + return self.graph.get_all_edges() + + def get_node_edges(self, source_node_id: str) -> Any: + return self.graph.get_node_edges(source_node_id) + + def upsert_node(self, node_id: str, node_data: dict[str, str]): + return self.graph.upsert_node(node_id, node_data) + + def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + return self.graph.upsert_edge(source_node_id, target_node_id, edge_data) + + def delete_node(self, node_id: str): + return self.graph.delete_node(node_id) + + def reload(self): + return self.graph.reload() + + +def get_actor_handle(name: str): + try: + return ray.get_actor(name) + except ValueError as exc: + raise RuntimeError( + f"Actor {name} not found. Make sure it is created before accessing." 
+ ) from exc + + +class RemoteKVStorageProxy(BaseKVStorage): + def __init__(self, namespace: str): + super().__init__() + self.namespace = namespace + self.actor_name = f"Actor_KV_{namespace}" + self.actor = get_actor_handle(self.actor_name) + + def data(self) -> Dict[str, Any]: + return ray.get(self.actor.data.remote()) + + def all_keys(self) -> list[str]: + return ray.get(self.actor.all_keys.remote()) + + def index_done_callback(self): + return ray.get(self.actor.index_done_callback.remote()) + + def get_by_id(self, id: str) -> Union[Any, None]: + return ray.get(self.actor.get_by_id.remote(id)) + + def get_by_ids(self, ids: list[str], fields=None) -> list[Any]: + return ray.get(self.actor.get_by_ids.remote(ids, fields)) + + def get_all(self) -> Dict[str, Any]: + return ray.get(self.actor.get_all.remote()) + + def filter_keys(self, data: list[str]) -> set[str]: + return ray.get(self.actor.filter_keys.remote(data)) + + def upsert(self, data: Dict[str, Any]): + return ray.get(self.actor.upsert.remote(data)) + + def drop(self): + return ray.get(self.actor.drop.remote()) + + def reload(self): + return ray.get(self.actor.reload.remote()) + + +class RemoteGraphStorageProxy(BaseGraphStorage): + def __init__(self, namespace: str): + super().__init__() + self.namespace = namespace + self.actor_name = f"Actor_Graph_{namespace}" + self.actor = get_actor_handle(self.actor_name) + + def index_done_callback(self): + return ray.get(self.actor.index_done_callback.remote()) + + def has_node(self, node_id: str) -> bool: + return ray.get(self.actor.has_node.remote(node_id)) + + def has_edge(self, source_node_id: str, target_node_id: str): + return ray.get(self.actor.has_edge.remote(source_node_id, target_node_id)) + + def node_degree(self, node_id: str) -> int: + return ray.get(self.actor.node_degree.remote(node_id)) + + def edge_degree(self, src_id: str, tgt_id: str) -> int: + return ray.get(self.actor.edge_degree.remote(src_id, tgt_id)) + + def get_node(self, node_id: str) -> Any: + return ray.get(self.actor.get_node.remote(node_id)) + + def update_node(self, node_id: str, node_data: dict[str, str]): + return ray.get(self.actor.update_node.remote(node_id, node_data)) + + def get_all_nodes(self) -> Any: + return ray.get(self.actor.get_all_nodes.remote()) + + def get_edge(self, source_node_id: str, target_node_id: str): + return ray.get(self.actor.get_edge.remote(source_node_id, target_node_id)) + + def update_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + return ray.get( + self.actor.update_edge.remote(source_node_id, target_node_id, edge_data) + ) + + def get_all_edges(self) -> Any: + return ray.get(self.actor.get_all_edges.remote()) + + def get_node_edges(self, source_node_id: str) -> Any: + return ray.get(self.actor.get_node_edges.remote(source_node_id)) + + def upsert_node(self, node_id: str, node_data: dict[str, str]): + return ray.get(self.actor.upsert_node.remote(node_id, node_data)) + + def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + return ray.get( + self.actor.upsert_edge.remote(source_node_id, target_node_id, edge_data) + ) + + def delete_node(self, node_id: str): + return ray.get(self.actor.delete_node.remote(node_id)) + + def reload(self): + return ray.get(self.actor.reload.remote()) + + +class StorageFactory: + """ + Factory class to create storage instances based on backend. 
+ """ + + @staticmethod + def create_storage(backend: str, working_dir: str, namespace: str): + if backend in ["json_kv", "rocksdb"]: + actor_name = f"Actor_KV_{namespace}" + try: + ray.get_actor(actor_name) + except ValueError: + ray.remote(KVStorageActor).options( + name=actor_name, + lifetime="detached", + get_if_exists=True, + ).remote(backend, working_dir, namespace) + return RemoteKVStorageProxy(namespace) + if backend in ["networkx", "kuzu"]: + actor_name = f"Actor_Graph_{namespace}" + try: + ray.get_actor(actor_name) + except ValueError: + ray.remote(GraphStorageActor).options( + name=actor_name, + lifetime="detached", + get_if_exists=True, + ).remote(backend, working_dir, namespace) + return RemoteGraphStorageProxy(namespace) + raise ValueError(f"Unknown storage backend: {backend}") + + +def init_storage(backend: str, working_dir: str, namespace: str): + return StorageFactory.create_storage(backend, working_dir, namespace) diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml deleted file mode 100644 index 9c53ec9c136591a625d70d8b3d4ed18768ddf176..0000000000000000000000000000000000000000 --- a/graphgen/configs/aggregated_config.yaml +++ /dev/null @@ -1,41 +0,0 @@ -pipeline: - - name: read_step # step name is unique in the pipeline, and can be referenced by other steps - op_key: read - params: - input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: quiz_and_judge_step - op_key: quiz_and_judge - deps: [build_kg_step] # quiz_and_judge depends on build_kg_step - params: - quiz_samples: 2 # number of quiz samples to generate - re_judge: false # whether to re-judge the existing quiz samples - - - name: partition_step - op_key: partition - deps: [quiz_and_judge_step] # partition_step depends on quiz_and_judge_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 20 # max nodes and edges per community - min_units_per_community: 5 # min nodes and edges per community - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: aggregated # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml deleted file mode 100644 index f8ae22183e64435cb4cf7d92e74a60ae11265160..0000000000000000000000000000000000000000 --- a/graphgen/configs/atomic_config.yaml +++ /dev/null @@ -1,31 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg] # partition_step depends on build_kg - params: - method: dfs # partition method, support: dfs, bfs, ece, leiden - method_params: - max_units_per_community: 1 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: atomic # atomic, aggregated, multi_hop, cot, vqa - data_format: Alpaca # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml deleted file mode 100644 index b09e341d007b3a10ce066059611599a0c4bf218b..0000000000000000000000000000000000000000 --- a/graphgen/configs/cot_config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg - params: - method: leiden # leiden is a partitioner detection algorithm - method_params: - max_size: 20 # Maximum size of communities - use_lcc: false # whether to use the largest connected component - random_seed: 42 # random seed for partitioning - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: cot # atomic, aggregated, multi_hop, cot, vqa - data_format: Sharegpt # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml deleted file mode 100644 index 4b8051b40e6865de78654d5af59a91a292ea3159..0000000000000000000000000000000000000000 --- a/graphgen/configs/multi_hop_config.yaml +++ /dev/null @@ -1,34 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg_step depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: ece # ece is a custom partition method based on comprehension loss - method_params: - max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3 - min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3 - max_tokens_per_community: 10240 # max tokens per community - unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: multi_hop # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/schema_guided_extraction_config.yaml b/graphgen/configs/schema_guided_extraction_config.yaml deleted file mode 100644 index 8d142ef6f28d6b91781ff68d7348bef8b2b55814..0000000000000000000000000000000000000000 --- a/graphgen/configs/schema_guided_extraction_config.yaml +++ /dev/null @@ -1,20 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/extract_demo.txt # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 20480 - chunk_overlap: 2000 - separators: [] - - - name: extract_step - op_key: extract - deps: [chunk_step] # extract_step depends on chunk_step - params: - method: schema_guided # extraction method, support: schema_guided - schema_file: graphgen/templates/extraction/schemas/legal_contract.json # schema file path for schema_guided method diff --git a/graphgen/configs/search_dna_config.yaml b/graphgen/configs/search_dna_config.yaml deleted file mode 100644 index f53a5eb8712279c2690d896da6ad8493dbbcfaec..0000000000000000000000000000000000000000 --- a/graphgen/configs/search_dna_config.yaml +++ /dev/null @@ -1,17 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/search_dna_demo.jsonl # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: search_step - op_key: search - deps: [read_step] # search_step depends on read_step - params: - data_sources: [ncbi] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral - ncbi_params: - email: test@example.com # NCBI requires an email address - tool: GraphGen # tool name for NCBI API - use_local_blast: true # whether to use local blast for DNA search - local_blast_db: refseq_release/refseq_release # path to local BLAST database (without .nhr extension) - diff --git a/graphgen/configs/search_protein_config.yaml b/graphgen/configs/search_protein_config.yaml deleted file mode 100644 index bfbf84eb1180d817065ea7718f8187a1c1342ee2..0000000000000000000000000000000000000000 --- a/graphgen/configs/search_protein_config.yaml +++ /dev/null @@ -1,15 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/search_protein_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: search_step - op_key: search - deps: [read_step] # search_step depends on read_step - params: - data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot - uniprot_params: - use_local_blast: true # whether to use local blast for uniprot search - local_blast_db: /your_path/2024_01/uniprot_sprot # format: /path/to/${RELEASE}/uniprot_sprot - # options: uniprot_sprot (recommended, high quality), uniprot_trembl, or uniprot_${RELEASE} (merged database) diff --git a/graphgen/configs/search_rna_config.yaml b/graphgen/configs/search_rna_config.yaml deleted file mode 100644 index 1042298820cb237dd2e1f09dbbd7f853f73cb9f1..0000000000000000000000000000000000000000 --- a/graphgen/configs/search_rna_config.yaml +++ /dev/null @@ -1,14 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/search_rna_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples - - - name: search_step - op_key: search - deps: [read_step] # search_step depends on read_step - params: - data_sources: [rnacentral] # data source for searcher, support: wikipedia, google, uniprot, ncbi, rnacentral - rnacentral_params: - use_local_blast: true # whether to use local blast for RNA search - local_blast_db: rnacentral_ensembl_gencode_YYYYMMDD/ensembl_gencode_YYYYMMDD # path to local BLAST database (without .nhr extension) diff --git a/graphgen/configs/vqa_config.yaml b/graphgen/configs/vqa_config.yaml deleted file mode 100644 index 06eba5c40cf65025ffd6cca6fd8d08984f26fbef..0000000000000000000000000000000000000000 --- a/graphgen/configs/vqa_config.yaml +++ /dev/null @@ -1,32 +0,0 @@ -pipeline: - - name: read_step - op_key: read - params: - input_file: resources/input_examples/vqa_demo.json # input file path, support json, jsonl, txt, pdf. 
See resources/input_examples for examples - - - name: chunk_step - op_key: chunk - deps: [read_step] # chunk_step depends on read_step - params: - chunk_size: 1024 # chunk size for text splitting - chunk_overlap: 100 # chunk overlap for text splitting - - - name: build_kg_step - op_key: build_kg - deps: [chunk_step] # build_kg depends on chunk_step - - - name: partition_step - op_key: partition - deps: [build_kg_step] # partition_step depends on build_kg_step - params: - method: anchor_bfs # partition method - method_params: - anchor_type: image # node type to select anchor nodes - max_units_per_community: 10 # atomic partition, one node or edge per community - - - name: generate_step - op_key: generate - deps: [partition_step] # generate_step depends on partition_step - params: - method: vqa # atomic, aggregated, multi_hop, cot, vqa - data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/engine.py b/graphgen/engine.py index 2989226c09f48fcdd23ee20a0692c6f9ba7081d6..62ab52818726b8c65aeaff5e15b9d8fedc30906d 100644 --- a/graphgen/engine.py +++ b/graphgen/engine.py @@ -1,125 +1,210 @@ -""" -orchestration engine for GraphGen -""" +import inspect +import logging +from collections import defaultdict, deque +from functools import wraps +from typing import Any, Callable, Dict, List, Set -import threading -import traceback -from typing import Any, Callable, List +import ray +import ray.data +from graphgen.bases import Config, Node +from graphgen.utils import logger -class Context(dict): - _lock = threading.Lock() - def set(self, k, v): - with self._lock: - self[k] = v - - def get(self, k, default=None): - with self._lock: - return super().get(k, default) - - -class OpNode: +class Engine: def __init__( - self, name: str, deps: List[str], func: Callable[["OpNode", Context], Any] + self, config: Dict[str, Any], functions: Dict[str, Callable], **ray_init_kwargs ): - self.name, self.deps, self.func = name, deps, func - + self.config = Config(**config) + self.global_params = self.config.global_params + self.functions = functions + self.datasets: Dict[str, ray.data.Dataset] = {} + + if not ray.is_initialized(): + context = ray.init( + ignore_reinit_error=True, + logging_level=logging.ERROR, + log_to_driver=True, + **ray_init_kwargs, + ) + logger.info("Ray Dashboard URL: %s", context.dashboard_url) -class Engine: - def __init__(self, max_workers: int = 4): - self.max_workers = max_workers - - def run(self, ops: List[OpNode], ctx: Context): - self._validate(ops) - name2op = {operation.name: operation for operation in ops} - - # topological sort - graph = {n: set(name2op[n].deps) for n in name2op} - topo = [] - q = [n for n, d in graph.items() if not d] - while q: - cur = q.pop(0) - topo.append(cur) - for child in [c for c, d in graph.items() if cur in d]: - graph[child].remove(cur) - if not graph[child]: - q.append(child) - - if len(topo) != len(ops): + @staticmethod + def _topo_sort(nodes: List[Node]) -> List[Node]: + id_to_node: Dict[str, Node] = {} + for n in nodes: + id_to_node[n.id] = n + + indeg: Dict[str, int] = {nid: 0 for nid in id_to_node} + adj: Dict[str, List[str]] = defaultdict(list) + + for n in nodes: + nid = n.id + deps: List[str] = n.dependencies + uniq_deps: Set[str] = set(deps) + for d in uniq_deps: + if d not in id_to_node: + raise ValueError( + f"The dependency node id {d} of node {nid} is not defined in the configuration." 
+ ) + indeg[nid] += 1 + adj[d].append(nid) + + zero_deg: deque = deque( + [id_to_node[nid] for nid, deg in indeg.items() if deg == 0] + ) + sorted_nodes: List[Node] = [] + + while zero_deg: + cur = zero_deg.popleft() + sorted_nodes.append(cur) + cur_id = cur.id + for nb_id in adj.get(cur_id, []): + indeg[nb_id] -= 1 + if indeg[nb_id] == 0: + zero_deg.append(id_to_node[nb_id]) + + if len(sorted_nodes) != len(nodes): + remaining = [nid for nid, deg in indeg.items() if deg > 0] raise ValueError( - "Cyclic dependencies detected among operations." - "Please check your configuration." + f"The configuration contains cycles, unable to execute. Remaining nodes with indegree > 0: {remaining}" ) - # semaphore for max_workers - sem = threading.Semaphore(self.max_workers) - done = {n: threading.Event() for n in name2op} - exc = {} - - def _exec(n: str): - with sem: - for d in name2op[n].deps: - done[d].wait() - if any(d in exc for d in name2op[n].deps): - exc[n] = Exception("Skipped due to failed dependencies") - done[n].set() - return - try: - name2op[n].func(name2op[n], ctx) - except Exception: - exc[n] = traceback.format_exc() - done[n].set() - - ts = [threading.Thread(target=_exec, args=(n,), daemon=True) for n in topo] - for t in ts: - t.start() - for t in ts: - t.join() - if exc: - raise RuntimeError( - "Some operations failed:\n" - + "\n".join(f"---- {op} ----\n{tb}" for op, tb in exc.items()) + return sorted_nodes + + def _get_input_dataset( + self, node: Node, initial_ds: ray.data.Dataset + ) -> ray.data.Dataset: + deps = node.dependencies + + if not deps: + return initial_ds + + if len(deps) == 1: + return self.datasets[deps[0]] + + main_ds = self.datasets[deps[0]] + other_dss = [self.datasets[d] for d in deps[1:]] + return main_ds.union(*other_dss) + + def _execute_node(self, node: Node, initial_ds: ray.data.Dataset): + def _filter_kwargs( + func_or_class: Callable, + global_params: Dict[str, Any], + func_params: Dict[str, Any], + ) -> Dict[str, Any]: + """ + 1. global_params: only when specified in function signature, will be passed + 2. 
func_params: pass specified params first, then **kwargs if exists + """ + try: + sig = inspect.signature(func_or_class) + except ValueError: + return {} + + params = sig.parameters + final_kwargs = {} + + has_var_keywords = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + valid_keys = set(params.keys()) + for k, v in global_params.items(): + if k in valid_keys: + final_kwargs[k] = v + + for k, v in func_params.items(): + if k in valid_keys or has_var_keywords: + final_kwargs[k] = v + return final_kwargs + + if node.op_name not in self.functions: + raise ValueError(f"Operator {node.op_name} not found for node {node.id}") + + op_handler = self.functions[node.op_name] + node_params = _filter_kwargs(op_handler, self.global_params, node.params or {}) + + if node.type == "source": + self.datasets[node.id] = op_handler(**node_params) + return + + input_ds = self._get_input_dataset(node, initial_ds) + + if inspect.isclass(op_handler): + execution_params = node.execution_params or {} + replicas = execution_params.get("replicas", 1) + batch_size = ( + int(execution_params.get("batch_size")) + if "batch_size" in execution_params + else "default" ) + compute_resources = execution_params.get("compute_resources", {}) + + if node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1), + batch_size=None, # aggregate processes the whole dataset at once + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) + else: + # others like map, filter, flatmap, map_batch let actors process data inside batches + self.datasets[node.id] = input_ds.map_batches( + op_handler, + compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas), + batch_size=batch_size, + num_gpus=compute_resources.get("num_gpus", 0) + if compute_resources + else 0, + fn_constructor_kwargs=node_params, + batch_format="pandas", + ) - @staticmethod - def _validate(ops: List[OpNode]): - name_set = set() - for op in ops: - if op.name in name_set: - raise ValueError(f"Duplicate operation name: {op.name}") - name_set.add(op.name) - for op in ops: - for dep in op.deps: - if dep not in name_set: - raise ValueError( - f"Operation {op.name} has unknown dependency: {dep}" - ) + else: + @wraps(op_handler) + def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]: + return op_handler(row_or_batch, **node_params) + + if node.type == "map": + self.datasets[node.id] = input_ds.map(func_wrapper) + elif node.type == "filter": + self.datasets[node.id] = input_ds.filter(func_wrapper) + elif node.type == "flatmap": + self.datasets[node.id] = input_ds.flat_map(func_wrapper) + elif node.type == "aggregate": + self.datasets[node.id] = input_ds.repartition(1).map_batches( + func_wrapper, batch_format="default" + ) + elif node.type == "map_batch": + self.datasets[node.id] = input_ds.map_batches(func_wrapper) + else: + raise ValueError( + f"Unsupported node type {node.type} for node {node.id}" + ) -def collect_ops(config: dict, graph_gen) -> List[OpNode]: - """ - build operation nodes from yaml config - :param config - :param graph_gen - """ - ops: List[OpNode] = [] - for stage in config["pipeline"]: - name = stage["name"] - method_name = stage.get("op_key") - method = getattr(graph_gen, method_name) - deps = stage.get("deps", []) + @staticmethod + def _find_leaf_nodes(nodes: List[Node]) -> Set[str]: + all_ids = {n.id for n in nodes} + 
deps_set = set() + for n in nodes: + deps_set.update(n.dependencies) + return all_ids - deps_set - if "params" in stage: + def execute(self, initial_ds: ray.data.Dataset) -> Dict[str, ray.data.Dataset]: + sorted_nodes = self._topo_sort(self.config.nodes) - def func(self, ctx, _method=method, _params=stage.get("params", {})): - return _method(_params) + for node in sorted_nodes: + self._execute_node(node, initial_ds) - else: + leaf_nodes = self._find_leaf_nodes(sorted_nodes) - def func(self, ctx, _method=method): - return _method() + @ray.remote + def _fetch_result(ds: ray.data.Dataset) -> List[Any]: + return ds.take_all() - op_node = OpNode(name=name, deps=deps, func=func) - ops.append(op_node) - return ops + return {node_id: self.datasets[node_id] for node_id in leaf_nodes} diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py deleted file mode 100644 index bc7e77425489f07ba0cdcd9d63ee4aea4f6db87f..0000000000000000000000000000000000000000 --- a/graphgen/graphgen.py +++ /dev/null @@ -1,295 +0,0 @@ -import os -import time -from typing import Dict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.datatypes import Chunk -from graphgen.models import ( - JsonKVStorage, - JsonListStorage, - NetworkXStorage, - OpenAIClient, - Tokenizer, -) -from graphgen.operators import ( - build_kg, - chunk_documents, - extract_info, - generate_qas, - init_llm, - judge_statement, - partition_kg, - quiz, - read_files, - search_all, -) -from graphgen.utils import async_to_sync_method, compute_mm_hash, logger - -sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - -class GraphGen: - def __init__( - self, - unique_id: int = int(time.time()), - working_dir: str = os.path.join(sys_path, "cache"), - tokenizer_instance: Tokenizer = None, - synthesizer_llm_client: OpenAIClient = None, - trainee_llm_client: OpenAIClient = None, - progress_bar: gr.Progress = None, - ): - self.unique_id: int = unique_id - self.working_dir: str = working_dir - - # llm - self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer( - model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base") - ) - - self.synthesizer_llm_client: BaseLLMWrapper = ( - synthesizer_llm_client or init_llm("synthesizer") - ) - self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client - - self.full_docs_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="full_docs" - ) - self.chunks_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="chunks" - ) - self.graph_storage: NetworkXStorage = NetworkXStorage( - self.working_dir, namespace="graph" - ) - self.rephrase_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="rephrase" - ) - self.partition_storage: JsonListStorage = JsonListStorage( - self.working_dir, namespace="partition" - ) - self.search_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="search", - ) - self.qa_storage: JsonListStorage = JsonListStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="qa", - ) - self.extract_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"), - namespace="extraction", - ) - - # webui - self.progress_bar: gr.Progress = progress_bar - - @async_to_sync_method - async def read(self, read_config: Dict): - """ - read files from input sources - """ - doc_stream = read_files(**read_config, cache_dir=self.working_dir) - - batch = {} 
- for doc in doc_stream: - doc_id = compute_mm_hash(doc, prefix="doc-") - batch[doc_id] = doc - - # TODO: configurable whether to use coreference resolution - - _add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys())) - new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys} - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - self.full_docs_storage.upsert(new_docs) - self.full_docs_storage.index_done_callback() - - @async_to_sync_method - async def chunk(self, chunk_config: Dict): - """ - chunk documents into smaller pieces from full_docs_storage if not already present - """ - - new_docs = self.full_docs_storage.get_all() - if len(new_docs) == 0: - logger.warning("All documents are already in the storage") - return - - inserting_chunks = await chunk_documents( - new_docs, - self.tokenizer_instance, - self.progress_bar, - **chunk_config, - ) - - _add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys())) - inserting_chunks = { - k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys - } - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - self.chunks_storage.upsert(inserting_chunks) - self.chunks_storage.index_done_callback() - - @async_to_sync_method - async def build_kg(self): - """ - build knowledge graph from text chunks - """ - # Step 1: get new chunks - inserting_chunks = self.chunks_storage.get_all() - - if len(inserting_chunks) == 0: - logger.warning("All chunks are already in the storage") - return - - logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks)) - # Step 2: build knowledge graph from new chunks - _add_entities_and_relations = await build_kg( - llm_client=self.synthesizer_llm_client, - kg_instance=self.graph_storage, - chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()], - progress_bar=self.progress_bar, - ) - if not _add_entities_and_relations: - logger.warning("No entities or relations extracted from text chunks") - return - - # Step 3: upsert new entities and relations to the graph storage - self.graph_storage.index_done_callback() - - return _add_entities_and_relations - - @async_to_sync_method - async def search(self, search_config: Dict): - logger.info("[Search] %s ...", ", ".join(search_config["data_sources"])) - - seeds = self.full_docs_storage.get_all() - if len(seeds) == 0: - logger.warning("All documents are already been searched") - return - search_results = await search_all( - seed_data=seeds, - search_config=search_config, - ) - - _add_search_keys = self.search_storage.filter_keys(list(search_results.keys())) - search_results = { - k: v for k, v in search_results.items() if k in _add_search_keys - } - if len(search_results) == 0: - logger.warning("All search results are already in the storage") - return - self.search_storage.upsert(search_results) - self.search_storage.index_done_callback() - - @async_to_sync_method - async def quiz_and_judge(self, quiz_and_judge_config: Dict): - logger.warning( - "Quiz and Judge operation needs trainee LLM client." - " Make sure to provide one." 
- ) - max_samples = quiz_and_judge_config["quiz_samples"] - await quiz( - self.synthesizer_llm_client, - self.graph_storage, - self.rephrase_storage, - max_samples, - progress_bar=self.progress_bar, - ) - - # TODO: assert trainee_llm_client is valid before judge - if not self.trainee_llm_client: - # TODO: shutdown existing synthesizer_llm_client properly - logger.info("No trainee LLM client provided, initializing a new one.") - self.synthesizer_llm_client.shutdown() - self.trainee_llm_client = init_llm("trainee") - - re_judge = quiz_and_judge_config["re_judge"] - _update_relations = await judge_statement( - self.trainee_llm_client, - self.graph_storage, - self.rephrase_storage, - re_judge, - progress_bar=self.progress_bar, - ) - - self.rephrase_storage.index_done_callback() - _update_relations.index_done_callback() - - logger.info("Shutting down trainee LLM client.") - self.trainee_llm_client.shutdown() - self.trainee_llm_client = None - logger.info("Restarting synthesizer LLM client.") - self.synthesizer_llm_client.restart() - - @async_to_sync_method - async def partition(self, partition_config: Dict): - batches = await partition_kg( - self.graph_storage, - self.chunks_storage, - self.tokenizer_instance, - partition_config, - ) - self.partition_storage.upsert(batches) - return batches - - @async_to_sync_method - async def extract(self, extract_config: Dict): - logger.info("Extracting information from given chunks...") - - results = await extract_info( - self.synthesizer_llm_client, - self.chunks_storage, - extract_config, - progress_bar=self.progress_bar, - ) - if not results: - logger.warning("No information extracted") - return - - self.extract_storage.upsert(results) - self.extract_storage.index_done_callback() - - @async_to_sync_method - async def generate(self, generate_config: Dict): - - batches = self.partition_storage.data - if not batches: - logger.warning("No partitions found for QA generation") - return - - # Step 2: generate QA pairs - results = await generate_qas( - self.synthesizer_llm_client, - batches, - generate_config, - progress_bar=self.progress_bar, - ) - - if not results: - logger.warning("No QA pairs generated") - return - - # Step 3: store the generated QA pairs - self.qa_storage.upsert(results) - self.qa_storage.index_done_callback() - - @async_to_sync_method - async def clear(self): - self.full_docs_storage.drop() - self.chunks_storage.drop() - self.search_storage.drop() - self.graph_storage.clear() - self.rephrase_storage.drop() - self.qa_storage.drop() - - logger.info("All caches are cleared") - - # TODO: add data filtering step here in the future - # graph_gen.filter(filter_config=config["filter"]) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3ef1ff69b7538ecceca196be96f66075ce2a4956..21344d740140ea45340152fc4b6aaa39b079c33d 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -18,7 +18,6 @@ from .partitioner import ( ) from .reader import ( CSVReader, - JSONLReader, JSONReader, ParquetReader, PDFReader, @@ -33,5 +32,11 @@ from .searcher.kg.wiki_search import WikiSearch from .searcher.web.bing_search import BingSearch from .searcher.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage, RocksDBCache +from .storage import ( + JsonKVStorage, + KuzuStorage, + NetworkXStorage, + RocksDBCache, + RocksDBKVStorage, +) from .tokenizer import Tokenizer diff --git 
a/graphgen/models/extractor/schema_guided_extractor.py b/graphgen/models/extractor/schema_guided_extractor.py index 70c45502c2e267b8ea1516fbe7b62b645e6944f0..7480194606d74b0201db5e5e8c191e1ee93c9b89 100644 --- a/graphgen/models/extractor/schema_guided_extractor.py +++ b/graphgen/models/extractor/schema_guided_extractor.py @@ -60,8 +60,8 @@ class SchemaGuidedExtractor(BaseExtractor): return prompt async def extract(self, chunk: dict) -> dict: - _chunk_id = list(chunk.keys())[0] - text = chunk[_chunk_id].get("content", "") + _chunk_id = chunk.get("_chunk_id", "") + text = chunk.get("content", "") prompt = self.build_prompt(text) response = await self.llm_client.generate_answer(prompt) @@ -88,9 +88,7 @@ class SchemaGuidedExtractor(BaseExtractor): return {} @staticmethod - async def merge_extractions( - extraction_list: List[Dict[str, dict]] - ) -> Dict[str, dict]: + def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]: """ Merge multiple extraction results based on their hashes. :param extraction_list: List of extraction results, each is a dict with hash as key and record as value. diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index eefbdd1c465ec27e81b5134f58e13c4a1eaf9e76..91b448627cf2f4b2f8c72f88e5d3dc3a6331d472 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -77,8 +77,8 @@ class VQAGenerator(BaseGenerator): nodes, _ = batch for node in nodes: node_data = node[1] - if "images" in node_data and node_data["images"]: - img_path = node_data["images"]["img_path"] + if "image_data" in node_data and node_data["image_data"]: + img_path = node_data["image_data"]["img_path"] for qa in qa_pairs.values(): qa["img_path"] = img_path result.update(qa_pairs) diff --git a/graphgen/models/llm/local/sglang_wrapper.py b/graphgen/models/llm/local/sglang_wrapper.py index 20400fdf3b63add41b8b7a36d3eb5c8d64b4094d..e8648613eb0775bb47e221c17f80594ee02840c2 100644 --- a/graphgen/models/llm/local/sglang_wrapper.py +++ b/graphgen/models/llm/local/sglang_wrapper.py @@ -138,15 +138,3 @@ class SGLangWrapper(BaseLLMWrapper): raise NotImplementedError( "SGLangWrapper does not support per-token logprobs yet." ) - - def shutdown(self) -> None: - """Gracefully shutdown the SGLang engine.""" - if hasattr(self, "engine"): - self.engine.shutdown() - - def restart(self) -> None: - """Restart the SGLang engine.""" - self.shutdown() - self.engine = self.engine.__class__( - model_path=self.model_path, tp_size=self.tp_size - ) diff --git a/graphgen/models/llm/local/vllm_wrapper.py b/graphgen/models/llm/local/vllm_wrapper.py index d3f6cfccf2ac349bb409d34520fca52c02dbcc5e..b8d8a6de37a620bd27f06e4c7cf6377d8d841d20 100644 --- a/graphgen/models/llm/local/vllm_wrapper.py +++ b/graphgen/models/llm/local/vllm_wrapper.py @@ -1,3 +1,5 @@ +import math +import uuid from typing import Any, List, Optional from graphgen.bases.base_llm_wrapper import BaseLLMWrapper @@ -6,7 +8,7 @@ from graphgen.bases.datatypes import Token class VLLMWrapper(BaseLLMWrapper): """ - Async inference backend based on vLLM (https://github.com/vllm-project/vllm) + Async inference backend based on vLLM. """ def __init__( @@ -20,12 +22,11 @@ class VLLMWrapper(BaseLLMWrapper): **kwargs: Any, ): super().__init__(temperature=temperature, top_p=top_p, **kwargs) - try: from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams except ImportError as exc: raise ImportError( - "VLLMWrapper requires vllm. 
Install it with: uv pip install vllm --torch-backend=auto" + "VLLMWrapper requires vllm. Install it with: uv pip install vllm" ) from exc self.SamplingParams = SamplingParams @@ -35,9 +36,9 @@ class VLLMWrapper(BaseLLMWrapper): tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=gpu_memory_utilization, trust_remote_code=kwargs.get("trust_remote_code", True), + disable_log_stats=False, ) self.engine = AsyncLLMEngine.from_engine_args(engine_args) - self.temperature = temperature self.top_p = top_p self.topk = topk @@ -60,6 +61,7 @@ class VLLMWrapper(BaseLLMWrapper): self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> str: full_prompt = self._build_inputs(text, history) + request_id = f"graphgen_req_{uuid.uuid4()}" sp = self.SamplingParams( temperature=self.temperature if self.temperature > 0 else 1.0, @@ -67,71 +69,57 @@ class VLLMWrapper(BaseLLMWrapper): max_tokens=extra.get("max_new_tokens", 512), ) - results = [] - async for req_output in self.engine.generate( - full_prompt, sp, request_id="graphgen_req" - ): - results = req_output.outputs - return results[-1].text + result_generator = self.engine.generate(full_prompt, sp, request_id=request_id) + + final_output = None + async for request_output in result_generator: + final_output = request_output + + if not final_output or not final_output.outputs: + return "" + + return final_output.outputs[0].text async def generate_topk_per_token( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: full_prompt = self._build_inputs(text, history) + request_id = f"graphgen_topk_{uuid.uuid4()}" + sp = self.SamplingParams( temperature=0, max_tokens=1, logprobs=self.topk, ) - results = [] - async for req_output in self.engine.generate( - full_prompt, sp, request_id="graphgen_topk" + result_generator = self.engine.generate(full_prompt, sp, request_id=request_id) + + final_output = None + async for request_output in result_generator: + final_output = request_output + + if ( + not final_output + or not final_output.outputs + or not final_output.outputs[0].logprobs ): - results = req_output.outputs - top_logprobs = results[-1].logprobs[0] + return [] + + top_logprobs = final_output.outputs[0].logprobs[0] tokens = [] for _, logprob_obj in top_logprobs.items(): tok_str = logprob_obj.decoded_token - prob = float(logprob_obj.logprob.exp()) + prob = float(math.exp(logprob_obj.logprob)) tokens.append(Token(tok_str, prob)) + tokens.sort(key=lambda x: -x.prob) return tokens async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None, **extra: Any ) -> List[Token]: - full_prompt = self._build_inputs(text, history) - - # vLLM 没有现成的“mask 一个 token 再算 prob”接口, - # 我们采用最直观的方式:把 prompt 一次性送进去,打开 - # prompt_logprobs=True,让 vLLM 返回 *输入部分* 每个位置的 - # logprob,然后挑出对应 token 的概率即可。 - sp = self.SamplingParams( - temperature=0, - max_tokens=0, # 不生成新 token - prompt_logprobs=1, # 只要 top-1 就够了 + raise NotImplementedError( + "VLLMWrapper does not support per-token logprobs yet." 
) - - results = [] - async for req_output in self.engine.generate( - full_prompt, sp, request_id="graphgen_prob" - ): - results = req_output.outputs - - # prompt_logprobs 是一个 list,长度 = prompt token 数, - # 每个元素是 dict{token_id: logprob_obj} 或 None(首个位置为 None) - prompt_logprobs = results[-1].prompt_logprobs - - tokens = [] - for _, logprob_dict in enumerate(prompt_logprobs): - if logprob_dict is None: - continue - # 这里每个 dict 只有 1 个 kv,因为 top-1 - _, logprob_obj = next(iter(logprob_dict.items())) - tok_str = logprob_obj.decoded_token - prob = float(logprob_obj.logprob.exp()) - tokens.append(Token(tok_str, prob)) - return tokens diff --git a/graphgen/models/partitioner/anchor_bfs_partitioner.py b/graphgen/models/partitioner/anchor_bfs_partitioner.py index 6cc1400c13f1b0824b7fcb141b06544001862b81..09133af73221297c26eff58d90c30999bd3780e7 100644 --- a/graphgen/models/partitioner/anchor_bfs_partitioner.py +++ b/graphgen/models/partitioner/anchor_bfs_partitioner.py @@ -1,6 +1,6 @@ import random from collections import deque -from typing import Any, List, Literal, Set, Tuple +from typing import Any, Iterable, List, Literal, Set, Tuple from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community @@ -30,24 +30,23 @@ class AnchorBFSPartitioner(BFSPartitioner): self.anchor_type = anchor_type self.anchor_ids = anchor_ids - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() # List[tuple[id, meta]] edges = g.get_all_edges() # List[tuple[u, v, meta]] adj, _ = self._build_adjacency_list(nodes, edges) - anchors: Set[str] = await self._pick_anchor_ids(nodes) + anchors: Set[str] = self._pick_anchor_ids(nodes) if not anchors: - return [] # if no anchors, return empty list + return # if no anchors, return nothing used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] seeds = list(anchors) random.shuffle(seeds) @@ -55,17 +54,13 @@ class AnchorBFSPartitioner(BFSPartitioner): for seed_node in seeds: if seed_node in used_n: continue - comm_n, comm_e = await self._grow_community( + comm_n, comm_e = self._grow_community( seed_node, adj, max_units_per_community, used_n, used_e ) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) + yield Community(id=seed_node, nodes=comm_n, edges=comm_e) - return communities - - async def _pick_anchor_ids( + def _pick_anchor_ids( self, nodes: List[tuple[str, dict]], ) -> Set[str]: @@ -80,7 +75,7 @@ class AnchorBFSPartitioner(BFSPartitioner): return anchor_ids @staticmethod - async def _grow_community( + def _grow_community( seed: str, adj: dict[str, List[str]], max_units: int, diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py index 008957121e9f0667e2493fbf18a459b947a18819..994e08e812a7ce9f7355a3b73d79b6fd5214e6a7 100644 --- a/graphgen/models/partitioner/bfs_partitioner.py +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -1,6 +1,6 @@ import random from collections import deque -from typing import Any, List +from typing import Any, Iterable, List from graphgen.bases import BaseGraphStorage, BasePartitioner from graphgen.bases.datatypes import Community @@ -17,12 +17,12 @@ class BFSPartitioner(BasePartitioner): (A unit is a node or an edge.) 
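+    Yields communities one at a time instead of returning a full list.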
""" - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -30,7 +30,6 @@ class BFSPartitioner(BasePartitioner): used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -74,8 +73,4 @@ class BFSPartitioner(BasePartitioner): queue.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index 6c394b10bfd47bee701d0fcc6588e73d1ca67ecd..4d93ad7f2e5d2523c3a359f3f982bb22ce914409 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -1,5 +1,6 @@ import random -from typing import Any, List +from collections.abc import Iterable +from typing import Any from graphgen.bases import BaseGraphStorage, BasePartitioner from graphgen.bases.datatypes import Community @@ -16,12 +17,12 @@ class DFSPartitioner(BasePartitioner): (In GraphGen, a unit is defined as a node or an edge.) """ - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 1, **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes = g.get_all_nodes() edges = g.get_all_edges() @@ -29,7 +30,6 @@ class DFSPartitioner(BasePartitioner): used_n: set[str] = set() used_e: set[frozenset[str]] = set() - communities: List[Community] = [] units = [(NODE_UNIT, n[0]) for n in nodes] + [ (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges @@ -71,8 +71,4 @@ class DFSPartitioner(BasePartitioner): stack.append((NODE_UNIT, n)) if comm_n or comm_e: - communities.append( - Community(id=len(communities), nodes=comm_n, edges=comm_e) - ) - - return communities + yield Community(id=seed, nodes=comm_n, edges=comm_e) diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 7de73181888ee31e644bfeeeda92501cf1af2b1e..fcf776c71ed5530d08145967685fb7a2a44506ab 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,8 +1,8 @@ -import asyncio import random -from typing import Any, Dict, List, Optional, Set, Tuple +from collections import deque +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple -from tqdm.asyncio import tqdm as tqdm_async +from tqdm import tqdm from graphgen.bases import BaseGraphStorage from graphgen.bases.datatypes import Community @@ -51,7 +51,7 @@ class ECEPartitioner(BFSPartitioner): raise ValueError(f"Invalid edge sampling: {edge_sampling}") return units - async def partition( + def partition( self, g: BaseGraphStorage, max_units_per_community: int = 10, @@ -59,7 +59,7 @@ class ECEPartitioner(BFSPartitioner): max_tokens_per_community: int = 10240, unit_sampling: str = "random", **kwargs: Any, - ) -> List[Community]: + ) -> Iterable[Community]: nodes: List[Tuple[str, dict]] = g.get_all_nodes() edges: List[Tuple[str, str, dict]] = g.get_all_edges() @@ -73,21 +73,18 @@ class ECEPartitioner(BFSPartitioner): used_n: Set[str] = set() used_e: Set[frozenset[str]] = set() - communities: List = [] all_units = self._sort_units(all_units, unit_sampling) 
- async def _grow_community( - seed_unit: Tuple[str, Any, dict] - ) -> Optional[Community]: + def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]: nonlocal used_n, used_e community_nodes: Dict[str, dict] = {} community_edges: Dict[frozenset[str], dict] = {} - queue: asyncio.Queue = asyncio.Queue() + queue = deque() token_sum = 0 - async def _add_unit(u): + def _add_unit(u): nonlocal token_sum t, i, d = u if t == NODE_UNIT: # node @@ -103,11 +100,11 @@ class ECEPartitioner(BFSPartitioner): token_sum += d.get("length", 0) return True - await _add_unit(seed_unit) - await queue.put(seed_unit) + _add_unit(seed_unit) + queue.append(seed_unit) # BFS - while not queue.empty(): + while queue: if ( len(community_nodes) + len(community_edges) >= max_units_per_community @@ -115,7 +112,7 @@ class ECEPartitioner(BFSPartitioner): ): break - cur_type, cur_id, _ = await queue.get() + cur_type, cur_id, _ = queue.popleft() neighbors: List[Tuple[str, Any, dict]] = [] if cur_type == NODE_UNIT: @@ -136,26 +133,24 @@ class ECEPartitioner(BFSPartitioner): or token_sum >= max_tokens_per_community ): break - if await _add_unit(nb): - await queue.put(nb) + if _add_unit(nb): + queue.append(nb) if len(community_nodes) + len(community_edges) < min_units_per_community: return None return Community( - id=len(communities), + id=seed_unit[1], nodes=list(community_nodes.keys()), edges=[(u, v) for (u, v), _ in community_edges.items()], ) - async for unit in tqdm_async(all_units, desc="ECE partition"): + for unit in tqdm(all_units, desc="ECE partition"): utype, uid, _ = unit if (utype == NODE_UNIT and uid in used_n) or ( utype == EDGE_UNIT and uid in used_e ): continue - comm = await _grow_community(unit) - if comm is not None: - communities.append(comm) - - return communities + comm = _grow_community(unit) + if comm: + yield comm diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index 1f85789be99282d3ee7f6008c40297b55873845a..b62b8544c6f771db769effb8ef46d414e7e1f6c1 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner): Leiden partitioner that partitions the graph into communities using the Leiden algorithm. 
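+    Communities larger than max_size are split into smaller sub-communities.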
""" - async def partition( + def partition( self, g: BaseGraphStorage, max_size: int = 20, @@ -37,12 +37,10 @@ class LeidenPartitioner(BasePartitioner): nodes = g.get_all_nodes() # List[Tuple[str, dict]] edges = g.get_all_edges() # List[Tuple[str, str, dict]] - node2cid: Dict[str, int] = await self._run_leiden( - nodes, edges, use_lcc, random_seed - ) + node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed) if max_size is not None and max_size > 0: - node2cid = await self._split_communities(node2cid, max_size) + node2cid = self._split_communities(node2cid, max_size) cid2nodes: Dict[int, List[str]] = defaultdict(list) for n, cid in node2cid.items(): @@ -58,7 +56,7 @@ class LeidenPartitioner(BasePartitioner): return communities @staticmethod - async def _run_leiden( + def _run_leiden( nodes: List[Tuple[str, dict]], edges: List[Tuple[str, str, dict]], use_lcc: bool = False, @@ -92,9 +90,7 @@ class LeidenPartitioner(BasePartitioner): return node2cid @staticmethod - async def _split_communities( - node2cid: Dict[str, int], max_size: int - ) -> Dict[str, int]: + def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]: """ Split communities larger than max_size into smaller sub-communities. """ diff --git a/graphgen/models/reader/__init__.py b/graphgen/models/reader/__init__.py index 600ffb4a15d62fb490b2d08e94445b49819cde98..220460c32a5f5584b568f0d214951938bb0e8e9f 100644 --- a/graphgen/models/reader/__init__.py +++ b/graphgen/models/reader/__init__.py @@ -1,6 +1,5 @@ from .csv_reader import CSVReader from .json_reader import JSONReader -from .jsonl_reader import JSONLReader from .parquet_reader import ParquetReader from .pdf_reader import PDFReader from .pickle_reader import PickleReader diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index bc865a3b9f4a1ff273136da27163b4defb036ee1..a0343d970e027100ca20e00da96d0a2b1d35f71a 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,13 +14,15 @@ class CSVReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: + def read(self, input_path: Union[str, List[str]]) -> Dataset: + """ + Read CSV files and return Ray Dataset. - df = pd.read_csv(file_path) - for _, row in df.iterrows(): - assert "type" in row, f"Missing 'type' column in document: {row.to_dict()}" - if row["type"] == "text" and self.text_column not in row: - raise ValueError( - f"Missing '{self.text_column}' in document: {row.to_dict()}" - ) - return self.filter(df.to_dict(orient="records")) + :param input_path: Path to CSV file or list of CSV files. + :return: Ray Dataset containing validated and filtered data. 
+ """ + + ds = ray.data.read_csv(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 8253041c2fa7b59f125192896d8942008d35d076..6752e0422616b34a9436fcec38bbeff37b58ede5 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -1,26 +1,53 @@ import json -from typing import Any, Dict, List +from typing import List, Union + +import ray +import ray.data from graphgen.bases.base_reader import BaseReader class JSONReader(BaseReader): """ - Reader for JSON files. + Reader for JSON and JSONL files. Columns: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - data = json.load(f) - if isinstance(data, list): - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - return self.filter(data) - raise ValueError("JSON file must contain a list of documents.") + def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: + """ + Read JSON file and return Ray Dataset. + :param input_path: Path to JSON/JSONL file or list of JSON/JSONL files. + :return: Ray Dataset containing validated and filtered data. + """ + if self.modalities and len(self.modalities) >= 2: + ds: ray.data.Dataset = ray.data.from_items([]) + for file in input_path if isinstance(input_path, list) else [input_path]: + data = [] + if file.endswith(".jsonl"): + with open(file, "r", encoding="utf-8") as f: + for line in f: + item = json.loads(line) + data.append(item) + else: + with open(file, "r", encoding="utf-8") as f: + data = json.load(f) + data = self._unify_schema(data) + file_ds: ray.data.Dataset = ray.data.from_items(data) + ds = ds.union(file_ds) # type: ignore + else: + ds = ray.data.read_json(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds + + @staticmethod + def _unify_schema(data): + """ + Unify schema for JSON data. + """ + for item in data: + if "content" in item and isinstance(item["content"], dict): + item["content"] = json.dumps(item["content"]) + return data diff --git a/graphgen/models/reader/jsonl_reader.py b/graphgen/models/reader/jsonl_reader.py deleted file mode 100644 index 31bc319582c08491725c20755b0878e29df14d43..0000000000000000000000000000000000000000 --- a/graphgen/models/reader/jsonl_reader.py +++ /dev/null @@ -1,30 +0,0 @@ -import json -from typing import Any, Dict, List - -from graphgen.bases.base_reader import BaseReader -from graphgen.utils import logger - - -class JSONLReader(BaseReader): - """ - Reader for JSONL files. - Columns: - - type: The type of the document (e.g., "text", "image", etc.) - - if type is "text", "content" column must be present. 
- """ - - def read(self, file_path: str) -> List[Dict[str, Any]]: - docs = [] - with open(file_path, "r", encoding="utf-8") as f: - for line in f: - try: - doc = json.loads(line) - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError( - f"Missing '{self.text_column}' in document: {doc}" - ) - docs.append(doc) - except json.JSONDecodeError as e: - logger.error("Error decoding JSON line: %s. Error: %s", line, e) - return self.filter(docs) diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index a325b876eb85a4d58615479b0907f36e83a0b8f0..dd289e318cef8df74c5d45c9a4abc51424591afd 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, List +from typing import List, Union -import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader @@ -13,12 +14,17 @@ class ParquetReader(BaseReader): - if type is "text", "content" column must be present. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - df = pd.read_parquet(file_path) - data: List[Dict[str, Any]] = df.to_dict(orient="records") + def read(self, input_path: Union[str, List[str]]) -> Dataset: + """ + Read Parquet files using Ray Data. - for doc in data: - assert "type" in doc, f"Missing 'type' in document: {doc}" - if doc.get("type") == "text" and self.text_column not in doc: - raise ValueError(f"Missing '{self.text_column}' in document: {doc}") - return self.filter(data) + :param input_path: Path to Parquet file or list of Parquet files. + :return: Ray Dataset containing validated documents. + """ + if not ray.is_initialized(): + ray.init() + + ds = ray.data.read_parquet(input_path) + ds = ds.map_batches(self._validate_batch, batch_format="pandas") + ds = ds.filter(self._should_keep_item) + return ds diff --git a/graphgen/models/reader/pdf_reader.py b/graphgen/models/reader/pdf_reader.py index 94562cb5df50c5f57dc50f423177ba79603e42ac..55dab30b3790f6d23755e88da80813863a907587 100644 --- a/graphgen/models/reader/pdf_reader.py +++ b/graphgen/models/reader/pdf_reader.py @@ -5,6 +5,9 @@ import tempfile from pathlib import Path from typing import Any, Dict, List, Optional, Union +import ray +from ray.data import Dataset + from graphgen.bases.base_reader import BaseReader from graphgen.models.reader.txt_reader import TXTReader from graphgen.utils import logger, pick_device @@ -62,19 +65,31 @@ class PDFReader(BaseReader): self.parser = MinerUParser() self.txt_reader = TXTReader() - def read(self, file_path: str, **override) -> List[Dict[str, Any]]: - """ - file_path - **override: override MinerU parameters - """ - pdf_path = Path(file_path).expanduser().resolve() - if not pdf_path.is_file(): - raise FileNotFoundError(pdf_path) + def read( + self, + input_path: Union[str, List[str]], + **override, + ) -> Dataset: + + # Ensure input_path is a list + if isinstance(input_path, str): + input_path = [input_path] + + paths_ds = ray.data.from_items(input_path) + + def process_pdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + try: + pdf_path = row["item"] + kwargs = {**self._default_kwargs, **override} + return self._call_mineru(Path(pdf_path), kwargs) + except Exception as e: + logger.error("Failed to process %s: %s", row, e) + return [] - kwargs = {**self._default_kwargs, **override} + docs_ds = paths_ds.flat_map(process_pdf) + docs_ds = 
docs_ds.filter(self._should_keep_item) - mineru_result = self._call_mineru(pdf_path, kwargs) - return self.filter(mineru_result) + return docs_ds def _call_mineru( self, pdf_path: Path, kwargs: Dict[str, Any] @@ -161,18 +176,18 @@ class MinerUParser: base = os.path.dirname(json_file) results = [] - for item in data: + for it in data: for key in ("img_path", "table_img_path", "equation_img_path"): - rel_path = item.get(key) + rel_path = it.get(key) if rel_path: - item[key] = str(Path(base).joinpath(rel_path).resolve()) - if item["type"] == "text": - item["content"] = item["text"] - del item["text"] + it[key] = str(Path(base).joinpath(rel_path).resolve()) + if it["type"] == "text": + it["content"] = it["text"] + del it["text"] for key in ("page_idx", "bbox", "text_level"): - if item.get(key) is not None: - del item[key] - results.append(item) + if it.get(key) is not None: + del it[key] + results.append(it) return results @staticmethod diff --git a/graphgen/models/reader/pickle_reader.py b/graphgen/models/reader/pickle_reader.py index 1a11dc1163fb8ae01f9269dce7b2f98425b6f5fe..6e3d1949c50b2fab564da643783e094190356aa3 100644 --- a/graphgen/models/reader/pickle_reader.py +++ b/graphgen/models/reader/pickle_reader.py @@ -1,30 +1,78 @@ import pickle -from typing import Any, Dict, List +from typing import List, Union + +import pandas as pd +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader +from graphgen.utils import logger class PickleReader(BaseReader): """ - Read pickle files, requiring the top-level object to be List[Dict[str, Any]]. - - Columns: + Read pickle files, requiring the schema to be restored to List[Dict[str, Any]]. + Each pickle file should contain a list of dictionaries with at least: - type: The type of the document (e.g., "text", "image", etc.) - if type is "text", "content" column must be present. + + Note: Uses ray.data.read_binary_files as ray.data.read_pickle is not available. + For Ray >= 2.5, consider using read_pickle if available in your version. """ - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "rb") as f: - data = pickle.load(f) + def read( + self, + input_path: Union[str, List[str]], + ) -> Dataset: + """ + Read Pickle files using Ray Data. + + :param input_path: Path to pickle file or list of pickle files. + :return: Ray Dataset containing validated documents. 
+        """
+        if not ray.is_initialized():
+            ray.init()
+
+        # Use read_binary_files as a reliable alternative to read_pickle
+        ds = ray.data.read_binary_files(input_path, include_paths=True)
+
+        # Deserialize pickle files and flatten into individual records
+        def deserialize_batch(batch: pd.DataFrame) -> pd.DataFrame:
+            all_records = []
+            for _, row in batch.iterrows():
+                try:
+                    # Load pickle data from bytes
+                    data = pickle.loads(row["bytes"])
+
+                    # Validate structure
+                    if not isinstance(data, list):
+                        logger.error(
+                            "Pickle file %s must contain a list, got %s", row["path"], type(data)
+                        )
+                        continue
+
+                    if not all(isinstance(item, dict) for item in data):
+                        logger.error(
+                            "Pickle file %s must contain a list of dictionaries", row["path"]
+                        )
+                        continue
+
+                    # Flatten: each dict in the list becomes a separate row
+                    all_records.extend(data)
+                except Exception as e:
+                    logger.error(
+                        "Failed to deserialize pickle file %s: %s", row["path"], str(e)
+                    )
+                    continue
+
+            return pd.DataFrame(all_records)
 
-        if not isinstance(data, list):
-            raise ValueError("Pickle file must contain a list of documents.")
+        # Apply deserialization and flattening
+        ds = ds.map_batches(deserialize_batch, batch_format="pandas")
 
-        for doc in data:
-            if not isinstance(doc, dict):
-                raise ValueError("Every item in the list must be a dict.")
-            assert "type" in doc, f"Missing 'type' in document: {doc}"
-            if doc.get("type") == "text" and self.text_column not in doc:
-                raise ValueError(f"Missing '{self.text_column}' in document: {doc}")
+        # Validate the schema
+        ds = ds.map_batches(self._validate_batch, batch_format="pandas")
 
-        return self.filter(data)
+        # Filter valid items
+        ds = ds.filter(self._should_keep_item)
+        return ds
diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py
index cce167c1fad03c31733100587a522242781776b0..9670107a1114d68fbac2415b27897507adaa9736 100644
--- a/graphgen/models/reader/rdf_reader.py
+++ b/graphgen/models/reader/rdf_reader.py
@@ -1,48 +1,128 @@
-from typing import Any, Dict, List
+from pathlib import Path
+from typing import Any, Dict, List, Union
 
+import ray
 import rdflib
+from ray.data import Dataset
 from rdflib import Literal
 from rdflib.util import guess_format
 
 from graphgen.bases.base_reader import BaseReader
+from graphgen.utils import logger
 
 
 class RDFReader(BaseReader):
     """
     Reader for RDF files that extracts triples and represents them as dictionaries.
+
+    Uses Ray Data for distributed processing of multiple RDF files.
     """
 
-    def read(self, file_path: str) -> List[Dict[str, Any]]:
+    def __init__(self, *, text_column: str = "content", **kwargs):
+        """
+        Initialize RDFReader.
+
+        :param text_column: The column name for text content (default: "content").
+        """
+        super().__init__(**kwargs)
+        self.text_column = text_column
+
+    def read(
+        self,
+        input_path: Union[str, List[str]],
+    ) -> Dataset:
+        """
+        Read RDF file(s) using Ray Data.
+
+        :param input_path: Path to RDF file or list of RDF files.
+        :return: Ray Dataset containing extracted documents.
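+        One document is produced per unique RDF subject.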
+ """ + if not ray.is_initialized(): + ray.init() + + # Ensure input_path is a list to prevent Ray from splitting string into characters + if isinstance(input_path, str): + input_path = [input_path] + + # Create dataset from file paths + paths_ds = ray.data.from_items(input_path) + + def process_rdf(row: Dict[str, Any]) -> List[Dict[str, Any]]: + """Process a single RDF file and return list of documents.""" + try: + file_path = row["item"] + return self._parse_rdf_file(Path(file_path)) + except Exception as e: + logger.error( + "Failed to process RDF file %s: %s", row.get("item", "unknown"), e + ) + return [] + + # Process files in parallel and flatten results + docs_ds = paths_ds.flat_map(process_rdf) + + # Filter valid documents + docs_ds = docs_ds.filter(self._should_keep_item) + + return docs_ds + + def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: + """ + Parse a single RDF file and extract documents. + + :param file_path: Path to RDF file. + :return: List of document dictionaries. + """ + if not file_path.is_file(): + raise FileNotFoundError(f"RDF file not found: {file_path}") + g = rdflib.Graph() - fmt = guess_format(file_path) + fmt = guess_format(str(file_path)) + try: - g.parse(file_path, format=fmt) + g.parse(str(file_path), format=fmt) except Exception as e: raise ValueError(f"Cannot parse RDF file {file_path}: {e}") from e docs: List[Dict[str, Any]] = [] - text_col = self.text_column + # Process each unique subject in the RDF graph for subj in set(g.subjects()): literals = [] props = {} + + # Extract all triples for this subject for _, pred, obj in g.triples((subj, None, None)): pred_str = str(pred) + obj_str = str(obj) + + # Collect literal values as text content if isinstance(obj, Literal): - literals.append(str(obj)) - props.setdefault(pred_str, []).append(str(obj)) + literals.append(obj_str) + + # Store all properties (including non-literals) + props.setdefault(pred_str, []).append(obj_str) + # Join all literal values as the text content text = " ".join(literals).strip() if not text: - raise ValueError( - f"Subject {subj} has no literal values; " - f"missing '{text_col}' for text column." + logger.warning( + "Subject %s in %s has no literal values; document will have empty '%s' field.", + subj, + file_path, + self.text_column, ) - doc = {"id": str(subj), text_col: text, "properties": props} + # Create document dictionary + doc = { + "id": str(subj), + self.text_column: text, + "properties": props, + "source_file": str(file_path), + } docs.append(doc) if not docs: - raise ValueError("RDF file contains no valid documents.") + logger.warning("RDF file %s contains no valid documents.", file_path) - return self.filter(docs) + return docs diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index ec2ff74747d4899c9a3b10f5f1fb5a01c9676905..51a47de2f5d5972c6d85cfe45753200a67296c39 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -1,10 +1,32 @@ -from typing import Any, Dict, List +from typing import List, Union + +import ray +from ray.data import Dataset from graphgen.bases.base_reader import BaseReader class TXTReader(BaseReader): - def read(self, file_path: str) -> List[Dict[str, Any]]: - with open(file_path, "r", encoding="utf-8") as f: - docs = [{"type": "text", self.text_column: f.read()}] - return self.filter(docs) + def read( + self, + input_path: Union[str, List[str]], + ) -> Dataset: + """ + Read text files from the specified input path. 
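+        Each file is read as a single document of type "text".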
+ :param input_path: Path to the input text file or list of text files. + :return: Ray Dataset containing the read text data. + """ + docs_ds = ray.data.read_binary_files( + input_path, + include_paths=False, + ) + + docs_ds = docs_ds.map( + lambda row: { + "type": "text", + self.text_column: row["bytes"].decode("utf-8"), + } + ) + + docs_ds = docs_ds.filter(self._should_keep_item) + return docs_ds diff --git a/graphgen/models/splitter/character_splitter.py b/graphgen/models/splitter/character_splitter.py index 1c91877e3ea0f384e14edb324ea171876e014b55..8877c861c797f5ed4efd7cf9edd9fbfdc639742c 100644 --- a/graphgen/models/splitter/character_splitter.py +++ b/graphgen/models/splitter/character_splitter.py @@ -17,7 +17,7 @@ class CharacterSplitter(BaseSplitter): def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. + # First we naively chunk the large input into a bunch of smaller ones. separator = ( self._separator if self._is_separator_regex else re.escape(self._separator) ) diff --git a/graphgen/models/splitter/markdown_splitter.py b/graphgen/models/splitter/markdown_splitter.py index 03def6ae058e0089230c383a7d876a2c8c6f6194..40b6a44e7f3adb94bf2fdff3fa9854d82d3fb7b9 100644 --- a/graphgen/models/splitter/markdown_splitter.py +++ b/graphgen/models/splitter/markdown_splitter.py @@ -6,12 +6,12 @@ from graphgen.models.splitter.recursive_character_splitter import ( class MarkdownTextRefSplitter(RecursiveCharacterSplitter): - """Attempts to split the text along Markdown-formatted headings.""" + """Attempts to chunk the text along Markdown-formatted headings.""" def __init__(self, **kwargs: Any) -> None: """Initialize a MarkdownTextRefSplitter.""" separators = [ - # First, try to split along Markdown headings (starting with level 2) + # First, try to chunk along Markdown headings (starting with level 2) "\n#{1,6} ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 diff --git a/graphgen/models/splitter/recursive_character_splitter.py b/graphgen/models/splitter/recursive_character_splitter.py index c9d7c543653f1ea3a89f15a5cd26ba170e61f5b0..b1ee8e0658a3591512bd47cca7a0938af58802a4 100644 --- a/graphgen/models/splitter/recursive_character_splitter.py +++ b/graphgen/models/splitter/recursive_character_splitter.py @@ -7,7 +7,7 @@ from graphgen.bases.base_splitter import BaseSplitter class RecursiveCharacterSplitter(BaseSplitter): """Splitting text by recursively look at characters. - Recursively tries to split by different characters to find one that works. + Recursively tries to chunk by different characters to find one that works. """ def __init__( @@ -88,7 +88,7 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter): def _split_text_with_regex_from_end( self, text: str, separator: str, keep_separator: bool ) -> List[str]: - # Now that we have the separator, split the text + # Now that we have the separator, chunk the text if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. 
diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py index 1e8f8341951f2bf3d5baeb3efb9033f499be7fcc..889a074c0384ee24dae108e7ff8dade3d9b916a7 100644 --- a/graphgen/models/storage/__init__.py +++ b/graphgen/models/storage/__init__.py @@ -1,3 +1,6 @@ -from .json_storage import JsonKVStorage, JsonListStorage -from .networkx_storage import NetworkXStorage +from graphgen.models.storage.graph.kuzu_storage import KuzuStorage +from graphgen.models.storage.graph.networkx_storage import NetworkXStorage +from graphgen.models.storage.kv.json_storage import JsonKVStorage +from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage + from .rocksdb_cache import RocksDBCache diff --git a/graphgen/configs/__init__.py b/graphgen/models/storage/graph/__init__.py similarity index 100% rename from graphgen/configs/__init__.py rename to graphgen/models/storage/graph/__init__.py diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/models/storage/graph/kuzu_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..4a221b8e486720b5745155dc0a7f6a8b910174d1 --- /dev/null +++ b/graphgen/models/storage/graph/kuzu_storage.py @@ -0,0 +1,256 @@ +import json +import os +import shutil +from dataclasses import dataclass +from typing import Any + +try: + import kuzu +except ImportError: + kuzu = None + +from graphgen.bases.base_storage import BaseGraphStorage + + +@dataclass +class KuzuStorage(BaseGraphStorage): + """ + Graph storage implementation based on KuzuDB. + Since KuzuDB is a structured graph database and GraphGen uses dynamic dictionaries for properties, + we map the data to a generic schema: + - Node Table 'Entity': {id: STRING, data: STRING (JSON)} + - Rel Table 'Relation': {FROM Entity TO Entity, data: STRING (JSON)} + """ + + working_dir: str = None + namespace: str = None + _db: Any = None + _conn: Any = None + + def __post_init__(self): + if kuzu is None: + raise ImportError( + "KuzuDB is not installed. Please install it via `pip install kuzu`." 
+ ) + + self.db_path = os.path.join(self.working_dir, f"{self.namespace}_kuzu") + self._init_db() + + def _init_db(self): + # KuzuDB automatically creates the directory + self._db = kuzu.Database(self.db_path) + self._conn = kuzu.Connection(self._db) + self._init_schema() + print(f"KuzuDB initialized at {self.db_path}") + + def _init_schema(self): + """Initialize the generic Node and Edge tables if they don't exist.""" + # Check and create Node table + try: + # We use a generic table name "Entity" to store all nodes + self._conn.execute( + "CREATE NODE TABLE Entity(id STRING, data STRING, PRIMARY KEY(id))" + ) + print("Created KuzuDB Node Table 'Entity'") + except RuntimeError as e: + # Usually throws if table exists, verify safely or ignore + print("Node Table 'Entity' already exists or error:", e) + + # Check and create Edge table + try: + # We use a generic table name "Relation" to store all edges + self._conn.execute( + "CREATE REL TABLE Relation(FROM Entity TO Entity, data STRING)" + ) + print("Created KuzuDB Rel Table 'Relation'") + except RuntimeError as e: + print("Rel Table 'Relation' already exists or error:", e) + + def index_done_callback(self): + """KuzuDB is ACID, changes are immediate, but we can verify generic persistence here.""" + + def has_node(self, node_id: str) -> bool: + result = self._conn.execute( + "MATCH (a:Entity {id: $id}) RETURN count(a)", {"id": node_id} + ) + count = result.get_next()[0] + return count > 0 + + def has_edge(self, source_node_id: str, target_node_id: str): + result = self._conn.execute( + "MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) RETURN count(e)", + {"src": source_node_id, "dst": target_node_id}, + ) + count = result.get_next()[0] + return count > 0 + + def node_degree(self, node_id: str) -> int: + # Calculate total degree (incoming + outgoing) + query = """ + MATCH (a:Entity {id: $id})-[e:Relation]-(b:Entity) + RETURN count(e) + """ + result = self._conn.execute(query, {"id": node_id}) + if result.has_next(): + return result.get_next()[0] + return 0 + + def edge_degree(self, src_id: str, tgt_id: str) -> int: + # In this context, usually checks existence or multiplicity. + # Kuzu supports multi-edges, so we count them. 
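+        # Only the outgoing direction src -> dst is matched here.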
+ query = """ + MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) + RETURN count(e) + """ + result = self._conn.execute(query, {"src": src_id, "dst": tgt_id}) + if result.has_next(): + return result.get_next()[0] + return 0 + + def get_node(self, node_id: str) -> Any: + result = self._conn.execute( + "MATCH (a:Entity {id: $id}) RETURN a.data", {"id": node_id} + ) + if result.has_next(): + data_str = result.get_next()[0] + return json.loads(data_str) if data_str else {} + return None + + def update_node(self, node_id: str, node_data: dict[str, str]): + current_data = self.get_node(node_id) + if current_data is None: + print(f"Node {node_id} not found for update.") + return + + # Merge existing data with new data + current_data.update(node_data) + json_data = json.dumps(current_data, ensure_ascii=False) + + self._conn.execute( + "MATCH (a:Entity {id: $id}) SET a.data = $data", + {"id": node_id, "data": json_data}, + ) + + def get_all_nodes(self) -> Any: + """Returns List[Tuple[id, data_dict]]""" + result = self._conn.execute("MATCH (a:Entity) RETURN a.id, a.data") + nodes = [] + while result.has_next(): + row = result.get_next() + nodes.append((row[0], json.loads(row[1]))) + return nodes + + def get_edge(self, source_node_id: str, target_node_id: str): + # Warning: If multiple edges exist, this returns the first one found + query = """ + MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) + RETURN e.data + """ + result = self._conn.execute( + query, {"src": source_node_id, "dst": target_node_id} + ) + if result.has_next(): + data_str = result.get_next()[0] + return json.loads(data_str) if data_str else {} + return None + + def update_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + current_data = self.get_edge(source_node_id, target_node_id) + if current_data is None: + print(f"Edge {source_node_id}->{target_node_id} not found for update.") + return + + current_data.update(edge_data) + json_data = json.dumps(current_data, ensure_ascii=False) + + query = """ + MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity {id: $dst}) + SET e.data = $data + """ + self._conn.execute( + query, {"src": source_node_id, "dst": target_node_id, "data": json_data} + ) + + def get_all_edges(self) -> Any: + """Returns List[Tuple[src, dst, data_dict]]""" + query = "MATCH (a:Entity)-[e:Relation]->(b:Entity) RETURN a.id, b.id, e.data" + result = self._conn.execute(query) + edges = [] + while result.has_next(): + row = result.get_next() + edges.append((row[0], row[1], json.loads(row[2]))) + return edges + + def get_node_edges(self, source_node_id: str) -> Any: + """Returns generic edges connected to this node (outgoing)""" + query = """ + MATCH (a:Entity {id: $src})-[e:Relation]->(b:Entity) + RETURN a.id, b.id, e.data + """ + result = self._conn.execute(query, {"src": source_node_id}) + edges = [] + while result.has_next(): + row = result.get_next() + edges.append((row[0], row[1], json.loads(row[2]))) + return edges + + def upsert_node(self, node_id: str, node_data: dict[str, str]): + """ + Insert or Update node. + Kuzu supports MERGE clause (similar to Neo4j) to handle upserts. + """ + json_data = json.dumps(node_data, ensure_ascii=False) + query = """ + MERGE (a:Entity {id: $id}) + ON MATCH SET a.data = $data + ON CREATE SET a.data = $data + """ + self._conn.execute(query, {"id": node_id, "data": json_data}) + + def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ): + """ + Insert or Update edge. 
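+        Uses Kuzu's MERGE clause, so repeated upserts overwrite the stored edge data.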
+ Note: We explicitly ensure nodes exist before merging the edge to avoid errors, + although GraphGen generally creates nodes before edges. + """ + # Ensure source node exists + if not self.has_node(source_node_id): + self.upsert_node(source_node_id, {}) + # Ensure target node exists + if not self.has_node(target_node_id): + self.upsert_node(target_node_id, {}) + + json_data = json.dumps(edge_data, ensure_ascii=False) + query = """ + MATCH (a:Entity {id: $src}), (b:Entity {id: $dst}) + MERGE (a)-[e:Relation]->(b) + ON MATCH SET e.data = $data + ON CREATE SET e.data = $data + """ + self._conn.execute( + query, {"src": source_node_id, "dst": target_node_id, "data": json_data} + ) + + def delete_node(self, node_id: str): + # DETACH DELETE removes the node and all connected edges + query = "MATCH (a:Entity {id: $id}) DETACH DELETE a" + self._conn.execute(query, {"id": node_id}) + print(f"Node {node_id} deleted from KuzuDB.") + + def clear(self): + """Clear all data but keep schema (or drop tables).""" + self._conn.execute("MATCH (n) DETACH DELETE n") + print(f"Graph {self.namespace} cleared.") + + def reload(self): + """For databases that need reloading, KuzuDB auto-manages this.""" + + def drop(self): + """Completely remove the database folder.""" + if self.db_path and os.path.exists(self.db_path): + shutil.rmtree(self.db_path) + print(f"Dropped KuzuDB at {self.db_path}") diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/graph/networkx_storage.py similarity index 85% rename from graphgen/models/storage/networkx_storage.py rename to graphgen/models/storage/graph/networkx_storage.py index 36bf1b5eb389eb6146ee4a035192d32b0b0e54e9..7fb73b79de95c72a1090802150d353e781cfcb46 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/graph/networkx_storage.py @@ -6,7 +6,6 @@ from typing import Any, Optional, Union, cast import networkx as nx from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.utils import logger @dataclass @@ -19,11 +18,6 @@ class NetworkXStorage(BaseGraphStorage): @staticmethod def write_nx_graph(graph: nx.Graph, file_name): - logger.info( - "Writing graph with %d nodes, %d edges", - graph.number_of_nodes(), - graph.number_of_edges(), - ) nx.write_graphml(graph, file_name) @staticmethod @@ -82,12 +76,11 @@ class NetworkXStorage(BaseGraphStorage): self.working_dir, f"{self.namespace}.graphml" ) preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file) - if preloaded_graph is not None: - logger.info( - "Loaded graph from %s with %d nodes, %d edges", - self._graphml_xml_file, - preloaded_graph.number_of_nodes(), - preloaded_graph.number_of_edges(), + if preloaded_graph: + print( + f"Loaded graph from {self._graphml_xml_file} with " + f"{preloaded_graph.number_of_nodes()} nodes, " + f"{preloaded_graph.number_of_edges()} edges" ) self._graph = preloaded_graph or nx.Graph() @@ -133,7 +126,7 @@ class NetworkXStorage(BaseGraphStorage): if self._graph.has_node(node_id): self._graph.nodes[node_id].update(node_data) else: - logger.warning("Node %s not found in the graph for update.", node_id) + print(f"Node {node_id} not found in the graph for update.") def upsert_edge( self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] @@ -146,10 +139,8 @@ class NetworkXStorage(BaseGraphStorage): if self._graph.has_edge(source_node_id, target_node_id): self._graph.edges[(source_node_id, target_node_id)].update(edge_data) else: - logger.warning( - "Edge %s -> %s not found in the graph for 
update.", - source_node_id, - target_node_id, + print( + f"Edge {source_node_id} -> {target_node_id} not found in the graph for update." ) def delete_node(self, node_id: str): @@ -160,13 +151,19 @@ class NetworkXStorage(BaseGraphStorage): """ if self._graph.has_node(node_id): self._graph.remove_node(node_id) - logger.info("Node %s deleted from the graph.", node_id) + print(f"Node {node_id} deleted from the graph.") else: - logger.warning("Node %s not found in the graph for deletion.", node_id) + print(f"Node {node_id} not found in the graph for deletion.") def clear(self): """ Clear the graph by removing all nodes and edges. """ self._graph.clear() - logger.info("Graph %s cleared.", self.namespace) + print(f"Graph {self.namespace} cleared.") + + def reload(self): + """ + Reload the graph from the GraphML file. + """ + self.__post_init__() diff --git a/graphgen/models/storage/kv/__init__.py b/graphgen/models/storage/kv/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/kv/json_storage.py similarity index 53% rename from graphgen/models/storage/json_storage.py rename to graphgen/models/storage/kv/json_storage.py index 53962117fb17b24ad18b75de7ffa4a2ed2afc21c..aa7c6f42168f5aa802029dccac7b6861f03f4b36 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/kv/json_storage.py @@ -1,8 +1,8 @@ import os from dataclasses import dataclass -from graphgen.bases.base_storage import BaseKVStorage, BaseListStorage -from graphgen.utils import load_json, logger, write_json +from graphgen.bases.base_storage import BaseKVStorage +from graphgen.utils import load_json, write_json @dataclass @@ -12,7 +12,7 @@ class JsonKVStorage(BaseKVStorage): def __post_init__(self): self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") self._data = load_json(self._file_name) or {} - logger.info("Load KV %s with %d data", self.namespace, len(self._data)) + print(f"Load KV {self.namespace} with {len(self._data)} data") @property def data(self): @@ -55,40 +55,6 @@ class JsonKVStorage(BaseKVStorage): if self._data: self._data.clear() - -@dataclass -class JsonListStorage(BaseListStorage): - working_dir: str = None - namespace: str = None - _data: list = None - - def __post_init__(self): - self._file_name = os.path.join(self.working_dir, f"{self.namespace}.json") - self._data = load_json(self._file_name) or [] - logger.info("Load List %s with %d data", self.namespace, len(self._data)) - - @property - def data(self): - return self._data - - def all_items(self) -> list: - return self._data - - def index_done_callback(self): - write_json(self._data, self._file_name) - - def get_by_index(self, index: int): - if index < 0 or index >= len(self._data): - return None - return self._data[index] - - def append(self, data): - self._data.append(data) - - def upsert(self, data: list): - left_data = [d for d in data if d not in self._data] - self._data.extend(left_data) - return left_data - - def drop(self): - self._data = [] + def reload(self): + self._data = load_json(self._file_name) or {} + print(f"Reload KV {self.namespace} with {len(self._data)} data") diff --git a/graphgen/models/storage/kv/rocksdb_storage.py b/graphgen/models/storage/kv/rocksdb_storage.py new file mode 100644 index 0000000000000000000000000000000000000000..45055b932ccf6e089ae7bd7d0e60207468fcf6f0 --- /dev/null +++ b/graphgen/models/storage/kv/rocksdb_storage.py @@ -0,0 +1,82 
@@ +import os +from dataclasses import dataclass +from typing import Any, Dict, List, Set + +# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it +# pylint: disable=no-name-in-module +from rocksdict import Rdict + +from graphgen.bases.base_storage import BaseKVStorage + + +@dataclass +class RocksDBKVStorage(BaseKVStorage): + _db: Rdict = None + _db_path: str = None + + def __post_init__(self): + self._db_path = os.path.join(self.working_dir, f"{self.namespace}.db") + self._db = Rdict(self._db_path) + print( + f"RocksDBKVStorage initialized for namespace '{self.namespace}' at '{self._db_path}'" + ) + + @property + def data(self): + return self._db + + def all_keys(self) -> List[str]: + return list(self._db.keys()) + + def index_done_callback(self): + self._db.flush() + print(f"RocksDB flushed for {self.namespace}") + + def get_by_id(self, id: str) -> Any: + return self._db.get(id, None) + + def get_by_ids(self, ids: List[str], fields: List[str] = None) -> List[Any]: + result = [] + for index in ids: + item = self._db.get(index, None) + if item is None: + result.append(None) + continue + + if fields is None: + result.append(item) + else: + result.append({k: v for k, v in item.items() if k in fields}) + return result + + def get_all(self) -> Dict[str, Dict]: + return dict(self._db) + + def filter_keys(self, data: List[str]) -> Set[str]: + return {s for s in data if s not in self._db} + + def upsert(self, data: Dict[str, Any]): + left_data = {} + for k, v in data.items(): + if k not in self._db: + left_data[k] = v + + if left_data: + for k, v in left_data.items(): + self._db[k] = v + # if left_data is very large, it is recommended to use self._db.write_batch() for optimization + + return left_data + + def drop(self): + self._db.close() + Rdict.destroy(self._db_path) + self._db = Rdict(self._db_path) + print(f"Dropped RocksDB {self.namespace}") + + def close(self): + if self._db: + self._db.close() + + def reload(self): + """For databases that need reloading, RocksDB auto-manages this.""" diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 97f4b3c8182a4a4a30949d120a2e2653194d4c1f..64c78af51a4545551283ee377094806d4680f600 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,9 +1,21 @@ -from .build_kg import build_kg -from .extract import extract_info -from .generate import generate_qas -from .init import init_llm -from .partition import partition_kg -from .quiz_and_judge import judge_statement, quiz -from .read import read_files +from .build_kg import BuildKGService +from .chunk import ChunkService +from .extract import ExtractService +from .generate import GenerateService +from .judge import JudgeService +from .partition import PartitionService +from .quiz import QuizService +from .read import read from .search import search_all -from .split import chunk_documents + +operators = { + "read": read, + "chunk": ChunkService, + "build_kg": BuildKGService, + "quiz": QuizService, + "judge": JudgeService, + "extract": ExtractService, + "search": search_all, + "partition": PartitionService, + "generate": GenerateService, +} diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py index 18766fe6311faf02e6d3fd40f8cc2ef34bb2fb24..a8b22ce94433d739167b5a0325ab1f71bb5bed35 100644 --- a/graphgen/operators/build_kg/__init__.py +++ b/graphgen/operators/build_kg/__init__.py @@ -1 +1 @@ -from .build_kg import build_kg +from .build_kg_service import BuildKGService diff 
--git a/graphgen/operators/build_kg/build_kg.py b/graphgen/operators/build_kg/build_kg.py deleted file mode 100644 index a8a6146dbd4807dbbf5bb7d7114e12d5d136d33b..0000000000000000000000000000000000000000 --- a/graphgen/operators/build_kg/build_kg.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import List - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.bases.datatypes import Chunk -from graphgen.utils import logger - -from .build_mm_kg import build_mm_kg -from .build_text_kg import build_text_kg - - -async def build_kg( - llm_client: BaseLLMWrapper, - kg_instance: BaseGraphStorage, - chunks: List[Chunk], - progress_bar: gr.Progress = None, -): - """ - Build knowledge graph (KG) and merge into kg_instance - :param llm_client: Synthesizer LLM model to extract entities and relationships - :param kg_instance - :param chunks - :param anchor_type: get this type of information from chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction - :return: - """ - - text_chunks = [chunk for chunk in chunks if chunk.type == "text"] - mm_chunks = [ - chunk - for chunk in chunks - if chunk.type in ("image", "video", "table", "formula") - ] - - if len(text_chunks) == 0: - logger.info("All text chunks are already in the storage") - else: - logger.info("[Text Entity and Relation Extraction] processing ...") - await build_text_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=text_chunks, - progress_bar=progress_bar, - ) - - if len(mm_chunks) == 0: - logger.info("All multi-modal chunks are already in the storage") - else: - logger.info("[Multi-modal Entity and Relation Extraction] processing ...") - await build_mm_kg( - llm_client=llm_client, - kg_instance=kg_instance, - chunks=mm_chunks, - progress_bar=progress_bar, - ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py new file mode 100644 index 0000000000000000000000000000000000000000..ec3c7cc135030f42b79f013f2e80c616192c3e3b --- /dev/null +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -0,0 +1,60 @@ +from typing import List + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator +from graphgen.bases.datatypes import Chunk +from graphgen.common import init_llm, init_storage +from graphgen.utils import logger + +from .build_mm_kg import build_mm_kg +from .build_text_kg import build_text_kg + + +class BuildKGService(BaseOperator): + def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): + super().__init__(working_dir=working_dir, op_name="build_kg_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] + + # consume the chunks and build kg + self.build_kg(docs) + return pd.DataFrame([{"status": "kg_building_completed"}]) + + def build_kg(self, chunks: List[Chunk]) -> None: + """ + Build knowledge graph (KG) and merge into kg_instance + """ + text_chunks = [chunk for chunk in chunks if chunk.type == "text"] + mm_chunks = [ + chunk + for chunk in chunks + if chunk.type in ("image", "video", "table", "formula") + ] + + if len(text_chunks) == 0: + logger.info("All text 
chunks are already in the storage") + else: + logger.info("[Text Entity and Relation Extraction] processing ...") + build_text_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=text_chunks, + ) + if len(mm_chunks) == 0: + logger.info("All multi-modal chunks are already in the storage") + else: + logger.info("[Multi-modal Entity and Relation Extraction] processing ...") + build_mm_kg( + llm_client=self.llm_client, + kg_instance=self.graph_storage, + chunks=mm_chunks, + ) + + self.graph_storage.index_done_callback() diff --git a/graphgen/operators/build_kg/build_mm_kg.py b/graphgen/operators/build_kg/build_mm_kg.py index 624b10ada934efe6abac212b73d4a3c2652d98d3..ee0459eac3d3e4d16e9db9c66066abe5e62edd2a 100644 --- a/graphgen/operators/build_kg/build_mm_kg.py +++ b/graphgen/operators/build_kg/build_mm_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.models import MMKGBuilder from graphgen.utils import run_concurrent -async def build_mm_kg( +def build_mm_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ Build multi-modal KG and merge into kg_instance :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ mm_builder = MMKGBuilder(llm_client=llm_client) - results = await run_concurrent( + results = run_concurrent( mm_builder.extract, chunks, desc="[2/4] Extracting entities and relationships from multi-modal chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_mm_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: mm_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 3c75f022513366d532be884fd66579c8778e774b..1b5a8762cfd63e71a3b66a3222c89d63675272d0 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -1,8 +1,6 @@ from collections import defaultdict from typing import List -import gradio as gr - from graphgen.bases import BaseLLMWrapper from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Chunk @@ -10,28 +8,25 @@ from graphgen.models import LightRAGKGBuilder from graphgen.utils import run_concurrent -async def build_text_kg( +def build_text_kg( llm_client: BaseLLMWrapper, kg_instance: BaseGraphStorage, chunks: List[Chunk], - progress_bar: gr.Progress = None, ): """ :param llm_client: Synthesizer LLM model to extract entities and relationships :param kg_instance :param chunks - :param progress_bar: Gradio progress bar to show the progress of the extraction :return: """ kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3) - results = await run_concurrent( + results = run_concurrent( kg_builder.extract, chunks, 
desc="[2/4]Extracting entities and relationships from chunks", unit="chunk", - progress_bar=progress_bar, ) nodes = defaultdict(list) @@ -42,16 +37,14 @@ async def build_text_kg( for k, v in e.items(): edges[tuple(sorted(k))].extend(v) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_nodes(kv, kg_instance=kg_instance), list(nodes.items()), desc="Inserting entities into storage", ) - await run_concurrent( + run_concurrent( lambda kv: kg_builder.merge_edges(kv, kg_instance=kg_instance), list(edges.items()), desc="Inserting relationships into storage", ) - - return kg_instance diff --git a/graphgen/operators/chunk/__init__.py b/graphgen/operators/chunk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2f116f7f926a4bbc77d156ef009c7c35347a6bf --- /dev/null +++ b/graphgen/operators/chunk/__init__.py @@ -0,0 +1 @@ +from .chunk_service import ChunkService diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py new file mode 100644 index 0000000000000000000000000000000000000000..102c74fd7a64aac18acfba3cdb13b0172d465fe5 --- /dev/null +++ b/graphgen/operators/chunk/chunk_service.py @@ -0,0 +1,103 @@ +import os +from functools import lru_cache +from typing import Union + +import pandas as pd + +from graphgen.bases import BaseOperator +from graphgen.common import init_storage +from graphgen.models import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, + Tokenizer, +) +from graphgen.utils import compute_content_hash, detect_main_language + +_MAPPING = { + "en": RecursiveCharacterSplitter, + "zh": ChineseRecursiveTextSplitter, +} + +SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] + + +@lru_cache(maxsize=None) +def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: + cls = _MAPPING[language] + kwargs = dict(frozen_kwargs) + return cls(**kwargs) + + +def split_chunks(text: str, language: str = "en", **kwargs) -> list: + if language not in _MAPPING: + raise ValueError( + f"Unsupported language: {language}. 
" + f"Supported languages are: {list(_MAPPING.keys())}" + ) + frozen_kwargs = frozenset( + (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() + ) + splitter = _get_splitter(language, frozen_kwargs) + return splitter.split_text(text) + + +class ChunkService(BaseOperator): + def __init__( + self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs + ): + super().__init__(working_dir=working_dir, op_name="chunk_service") + tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) + self.chunk_storage = init_storage( + backend=kv_backend, + working_dir=working_dir, + namespace="chunk", + ) + self.chunk_kwargs = chunk_kwargs + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + docs = batch.to_dict(orient="records") + return pd.DataFrame(self.chunk_documents(docs)) + + def chunk_documents(self, new_docs: list) -> list: + chunks = [] + for doc in new_docs: + doc_id = doc.get("_doc_id") + doc_type = doc.get("type") + + if doc_type == "text": + doc_language = detect_main_language(doc["content"]) + text_chunks = split_chunks( + doc["content"], + language=doc_language, + **self.chunk_kwargs, + ) + + chunks.extend( + [ + { + "_chunk_id": compute_content_hash( + chunk_text, prefix="chunk-" + ), + "content": chunk_text, + "type": "text", + "_doc_id": doc_id, + "length": len(self.tokenizer_instance.encode(chunk_text)) + if self.tokenizer_instance + else len(chunk_text), + "language": doc_language, + } + for chunk_text in text_chunks + ] + ) + else: + # other types of documents(images, sequences) are not chunked + chunks.append( + { + "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), + **doc, + } + ) + self.chunk_storage.upsert({chunk["_chunk_id"]: chunk for chunk in chunks}) + self.chunk_storage.index_done_callback() + return chunks diff --git a/graphgen/operators/evaluate/__init__.py b/graphgen/operators/evaluate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/graphgen/evaluate.py b/graphgen/operators/evaluate/evaluate.py similarity index 97% rename from graphgen/evaluate.py rename to graphgen/operators/evaluate/evaluate.py index d1e2413b9ac39050095dcea632c5c17607d51e51..fdbfbf82580cb8670950b04b17c44b7cbcee6100 100644 --- a/graphgen/evaluate.py +++ b/graphgen/operators/evaluate/evaluate.py @@ -9,9 +9,13 @@ import pandas as pd from dotenv import load_dotenv from graphgen.bases.datatypes import QAPair - -from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator -from .utils import logger, set_logger +from graphgen.models import ( + LengthEvaluator, + MTLDEvaluator, + RewardEvaluator, + UniEvaluator, +) +from graphgen.utils import logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log")) diff --git a/graphgen/operators/extract/__init__.py b/graphgen/operators/extract/__init__.py index ec576cb652eaa001616b6d212bf52896949eedaf..6c7c2b94cb764957dd5aaf76179bf206a2dc3d95 100644 --- a/graphgen/operators/extract/__init__.py +++ b/graphgen/operators/extract/__init__.py @@ -1 +1 @@ -from .extract_info import extract_info +from .extract_service import ExtractService diff --git a/graphgen/operators/extract/extract_info.py b/graphgen/operators/extract/extract_info.py deleted file mode 100644 index 8e65f1b2908c5050b1707211101e6bded4b78cfa..0000000000000000000000000000000000000000 --- 
a/graphgen/operators/extract/extract_info.py +++ /dev/null @@ -1,47 +0,0 @@ -import json - -import gradio as gr - -from graphgen.bases import BaseKVStorage, BaseLLMWrapper -from graphgen.models.extractor import SchemaGuidedExtractor -from graphgen.utils import logger, run_concurrent - - -async def extract_info( - llm_client: BaseLLMWrapper, - chunk_storage: BaseKVStorage, - extract_config: dict, - progress_bar: gr.Progress = None, -): - """ - Extract information from chunks - :param llm_client: LLM client - :param chunk_storage: storage for chunks - :param extract_config - :param progress_bar - :return: extracted information - """ - - method = extract_config.get("method") - if method == "schema_guided": - schema_file = extract_config.get("schema_file") - with open(schema_file, "r", encoding="utf-8") as f: - schema = json.load(f) - extractor = SchemaGuidedExtractor(llm_client, schema) - else: - raise ValueError(f"Unsupported extraction method: {method}") - - chunks = chunk_storage.get_all() - chunks = [{k: v} for k, v in chunks.items()] - logger.info("Start extracting information from %d chunks", len(chunks)) - - results = await run_concurrent( - extractor.extract, - chunks, - desc="Extracting information", - unit="chunk", - progress_bar=progress_bar, - ) - - results = await extractor.merge_extractions(results) - return results diff --git a/graphgen/operators/extract/extract_service.py b/graphgen/operators/extract/extract_service.py new file mode 100644 index 0000000000000000000000000000000000000000..33987fcb24171fce139cb25d6f750df83f4cb007 --- /dev/null +++ b/graphgen/operators/extract/extract_service.py @@ -0,0 +1,45 @@ +import json + +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm +from graphgen.models.extractor import SchemaGuidedExtractor +from graphgen.utils import logger, run_concurrent + + +class ExtractService(BaseOperator): + def __init__(self, working_dir: str = "cache", **extract_kwargs): + super().__init__(working_dir=working_dir, op_name="extract_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.extract_kwargs = extract_kwargs + self.method = self.extract_kwargs.get("method") + if self.method == "schema_guided": + schema_file = self.extract_kwargs.get("schema_path") + with open(schema_file, "r", encoding="utf-8") as f: + schema = json.load(f) + self.extractor = SchemaGuidedExtractor(self.llm_client, schema) + else: + raise ValueError(f"Unsupported extraction method: {self.method}") + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + return pd.DataFrame(self.extract(items)) + + def extract(self, items: list[dict]) -> list[dict]: + + logger.info("Start extracting information from %d items", len(items)) + + results = run_concurrent( + self.extractor.extract, + items, + desc="Extracting information", + unit="item", + ) + results = self.extractor.merge_extractions(results) + + results = [ + {"_extract_id": key, "extracted_data": value} + for key, value in results.items() + ] + return results diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py index 035eca3608197138d5c9f1ac1614792bd4a8754b..04057ce64eb95beaa8039825f182f83513588d06 100644 --- a/graphgen/operators/generate/__init__.py +++ b/graphgen/operators/generate/__init__.py @@ -1 +1 @@ -from .generate_qas import generate_qas +from .generate_service import GenerateService diff --git a/graphgen/operators/generate/generate_qas.py 
b/graphgen/operators/generate/generate_qas.py deleted file mode 100644 index 86dbb9c95eb73375fa6f628406ef87adeba26bf1..0000000000000000000000000000000000000000 --- a/graphgen/operators/generate/generate_qas.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import ( - AggregatedGenerator, - AtomicGenerator, - CoTGenerator, - MultiHopGenerator, - VQAGenerator, -) -from graphgen.utils import logger, run_concurrent - - -async def generate_qas( - llm_client: BaseLLMWrapper, - batches: list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - generation_config: dict, - progress_bar: gr.Progress = None, -) -> list[dict[str, Any]]: - """ - Generate question-answer pairs based on nodes and edges. - :param llm_client: LLM client - :param batches - :param generation_config - :param progress_bar - :return: QA pairs - """ - method = generation_config["method"] - logger.info("[Generation] mode: %s, batches: %d", method, len(batches)) - - if method == "atomic": - generator = AtomicGenerator(llm_client) - elif method == "aggregated": - generator = AggregatedGenerator(llm_client) - elif method == "multi_hop": - generator = MultiHopGenerator(llm_client) - elif method == "cot": - generator = CoTGenerator(llm_client) - elif method in ["vqa"]: - generator = VQAGenerator(llm_client) - else: - raise ValueError(f"Unsupported generation mode: {method}") - - results = await run_concurrent( - generator.generate, - batches, - desc="[4/4]Generating QAs", - unit="batch", - progress_bar=progress_bar, - ) - - # format - data_format = generation_config["data_format"] - logger.info("Output data format: %s", data_format) - - results = generator.format_generation_results( - results, output_data_format=data_format - ) - - return results diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae2f067f83b15bb5e1b71b91fa732552dd6227e --- /dev/null +++ b/graphgen/operators/generate/generate_service.py @@ -0,0 +1,68 @@ +import pandas as pd + +from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm +from graphgen.models import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, + VQAGenerator, +) +from graphgen.utils import logger, run_concurrent + + +class GenerateService(BaseOperator): + """ + Generate question-answer pairs based on nodes and edges. 
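+    Supported methods: "atomic", "aggregated", "multi_hop", "cot" and "vqa"; any other value raises ValueError in __init__.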
+ """ + + def __init__( + self, + working_dir: str = "cache", + method: str = "aggregated", + data_format: str = "ChatML", + ): + super().__init__(working_dir=working_dir, op_name="generate_service") + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + + self.method = method + self.data_format = data_format + + if self.method == "atomic": + self.generator = AtomicGenerator(self.llm_client) + elif self.method == "aggregated": + self.generator = AggregatedGenerator(self.llm_client) + elif self.method == "multi_hop": + self.generator = MultiHopGenerator(self.llm_client) + elif self.method == "cot": + self.generator = CoTGenerator(self.llm_client) + elif self.method in ["vqa"]: + self.generator = VQAGenerator(self.llm_client) + else: + raise ValueError(f"Unsupported generation mode: {method}") + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + return pd.DataFrame(self.generate(items)) + + def generate(self, items: list[dict]) -> list[dict]: + """ + Generate question-answer pairs based on nodes and edges. + :param items + :return: QA pairs + """ + logger.info("[Generation] mode: %s, batches: %d", self.method, len(items)) + items = [(item["nodes"], item["edges"]) for item in items] + results = run_concurrent( + self.generator.generate, + items, + desc="[4/4]Generating QAs", + unit="batch", + ) + + results = self.generator.format_generation_results( + results, output_data_format=self.data_format + ) + + return results diff --git a/graphgen/operators/init/__init__.py b/graphgen/operators/init/__init__.py deleted file mode 100644 index ec6044419e4ee1811a702a76aaac7eccb6771bf9..0000000000000000000000000000000000000000 --- a/graphgen/operators/init/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .init_llm import init_llm diff --git a/graphgen/operators/init/init_llm.py b/graphgen/operators/init/init_llm.py deleted file mode 100644 index e294d2c31564765f1e0a7f201a19e3b787e07dd1..0000000000000000000000000000000000000000 --- a/graphgen/operators/init/init_llm.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -from typing import Any, Dict, Optional - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import Tokenizer - - -class LLMFactory: - """ - A factory class to create LLM wrapper instances based on the specified backend. 
- Supported backends include: - - http_api: HTTPClient - - openai_api: OpenAIClient - - ollama_api: OllamaClient - - huggingface: HuggingFaceWrapper - - sglang: SGLangWrapper - """ - - @staticmethod - def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper: - # add tokenizer - tokenizer: Tokenizer = Tokenizer( - os.environ.get("TOKENIZER_MODEL", "cl100k_base"), - ) - config["tokenizer"] = tokenizer - if backend == "http_api": - from graphgen.models.llm.api.http_client import HTTPClient - - return HTTPClient(**config) - if backend in ("openai_api", "azure_openai_api"): - from graphgen.models.llm.api.openai_client import OpenAIClient - # pass in concrete backend to the OpenAIClient so that internally we can distinguish - # between OpenAI and Azure OpenAI - return OpenAIClient(**config, backend=backend) - if backend == "ollama_api": - from graphgen.models.llm.api.ollama_client import OllamaClient - - return OllamaClient(**config) - if backend == "huggingface": - from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper - - return HuggingFaceWrapper(**config) - if backend == "sglang": - from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper - - return SGLangWrapper(**config) - - # if backend == "vllm": - # from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper - # - # return VLLMWrapper(**config) - - raise NotImplementedError(f"Backend {backend} is not implemented yet.") - - -def _load_env_group(prefix: str) -> Dict[str, Any]: - """ - Collect environment variables with the given prefix into a dictionary, - stripping the prefix from the keys. - """ - return { - k[len(prefix) :].lower(): v - for k, v in os.environ.items() - if k.startswith(prefix) - } - - -def init_llm(model_type: str) -> Optional[BaseLLMWrapper]: - if model_type == "synthesizer": - prefix = "SYNTHESIZER_" - elif model_type == "trainee": - prefix = "TRAINEE_" - else: - raise NotImplementedError(f"Model type {model_type} is not implemented yet.") - config = _load_env_group(prefix) - # if config is empty, return None - if not config: - return None - backend = config.pop("backend") - llm_wrapper = LLMFactory.create_llm_wrapper(backend, config) - return llm_wrapper diff --git a/graphgen/operators/judge/__init__.py b/graphgen/operators/judge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..32ccf5c21bd13f9dc309f0cc6447f6a69e917522 --- /dev/null +++ b/graphgen/operators/judge/__init__.py @@ -0,0 +1 @@ +from .judge_service import JudgeService diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c7693aec98d02146aa81701aec611d30bf58e96f --- /dev/null +++ b/graphgen/operators/judge/judge_service.py @@ -0,0 +1,70 @@ +import math + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm, init_storage +from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT +from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy + + +class JudgeService(BaseOperator): + """Service for judging graph edges and nodes using a trainee LLM.""" + + def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): + super().__init__(working_dir=working_dir, op_name="judge_service") + self.llm_client: BaseLLMWrapper = init_llm("trainee") + self.graph_storage: BaseGraphStorage = init_storage( + backend=graph_backend, + working_dir=working_dir, + namespace="graph", + 
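+            # NOTE: the "graph" namespace must match the one BuildKGService writes to, so judging updates the same graph.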
) + + def process(self, batch: pd.DataFrame) -> pd.DataFrame: + items = batch.to_dict(orient="records") + self.graph_storage.reload() + self.judge(items) + return pd.DataFrame([{"status": "judging_completed"}]) + + async def _process_single_judge(self, item: dict) -> dict: + description = item["description"] + try: + judgement = await self.llm_client.generate_topk_per_token( + STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) + ) + top_candidates = judgement[0].top_candidates + gt = item.get("ground_truth", "yes") + loss = yes_no_loss_entropy([top_candidates], [gt]) + logger.debug("Description: %s Loss: %s", description, loss) + item["loss"] = loss + except Exception as e: # pylint: disable=broad-except + logger.error("Error in judging description: %s", e) + logger.info("Use default loss 0.1") + item["loss"] = -math.log(0.1) + return item + + def judge(self, items: list[dict]) -> None: + """ + Judge the description in the item and compute the loss. + """ + results = run_concurrent( + self._process_single_judge, + items, + desc="Judging descriptions", + unit="description", + ) + # Update the graph storage with the computed losses + for item in results: + index = item["index"] + loss = item["loss"] + if isinstance(index, str): + node_id = index + node_data = self.graph_storage.get_node(node_id) + node_data["loss"] = loss + self.graph_storage.update_node(node_id, node_data) + elif isinstance(index, tuple): + edge_source, edge_target = index + edge_data = self.graph_storage.get_edge(edge_source, edge_target) + edge_data["loss"] = loss + self.graph_storage.update_edge(edge_source, edge_target, edge_data) + self.graph_storage.index_done_callback() diff --git a/graphgen/operators/partition/__init__.py b/graphgen/operators/partition/__init__.py index 21f934b38b328b9fb804c5e6969bd85662c99a0e..8d586b95c6ef2c9c7fda862b8648c0c3766168c5 100644 --- a/graphgen/operators/partition/__init__.py +++ b/graphgen/operators/partition/__init__.py @@ -1 +1 @@ -from .partition_kg import partition_kg +from .partition_service import PartitionService diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_kg.py deleted file mode 100644 index 4c4fdaa1b8ae421aca31ea517c6b63684d722f35..0000000000000000000000000000000000000000 --- a/graphgen/operators/partition/partition_kg.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Any - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseTokenizer -from graphgen.models import ( - AnchorBFSPartitioner, - BFSPartitioner, - DFSPartitioner, - ECEPartitioner, - LeidenPartitioner, -) -from graphgen.utils import logger - -from .pre_tokenize import pre_tokenize - - -async def partition_kg( - kg_instance: BaseGraphStorage, - chunk_storage: BaseKVStorage, - tokenizer: Any = BaseTokenizer, - partition_config: dict = None, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - method = partition_config["method"] - method_params = partition_config["method_params"] - if method == "bfs": - logger.info("Partitioning knowledge graph using BFS method.") - partitioner = BFSPartitioner() - elif method == "dfs": - logger.info("Partitioning knowledge graph using DFS method.") - partitioner = DFSPartitioner() - elif method == "ece": - logger.info("Partitioning knowledge graph using ECE method.") - # TODO: before ECE partitioning, we need to: - # 1. 'quiz and judge' to get the comprehension loss if unit_sampling is not random - # 2. 
pre-tokenize nodes and edges to get the token length - edges = kg_instance.get_all_edges() - nodes = kg_instance.get_all_nodes() - await pre_tokenize(kg_instance, tokenizer, edges, nodes) - partitioner = ECEPartitioner() - elif method == "leiden": - logger.info("Partitioning knowledge graph using Leiden method.") - partitioner = LeidenPartitioner() - elif method == "anchor_bfs": - logger.info("Partitioning knowledge graph using Anchor BFS method.") - partitioner = AnchorBFSPartitioner( - anchor_type=method_params.get("anchor_type"), - anchor_ids=set(method_params.get("anchor_ids", [])) - if method_params.get("anchor_ids") - else None, - ) - else: - raise ValueError(f"Unsupported partition method: {method}") - - communities = await partitioner.partition(g=kg_instance, **method_params) - logger.info("Partitioned the graph into %d communities.", len(communities)) - batches = await partitioner.community2batch(communities, g=kg_instance) - - batches = await attach_additional_data_to_node(batches, chunk_storage) - return batches - - -async def attach_additional_data_to_node( - batches: list[ - tuple[ - list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] - ] - ], - chunk_storage: BaseKVStorage, -) -> list[ - tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] -]: - """ - Attach additional data from chunk_storage to nodes in the batches. - :param batches: - :param chunk_storage: - :return: - """ - for batch in batches: - for node_id, node_data in batch[0]: - await _attach_by_type(node_id, node_data, chunk_storage) - return batches - - -async def _attach_by_type( - node_id: str, - node_data: dict, - chunk_storage: BaseKVStorage, -) -> None: - """ - Attach additional data to the node based on its entity type. - """ - entity_type = (node_data.get("entity_type") or "").lower() - if not entity_type: - return - - source_ids = [ - sid.strip() - for sid in node_data.get("source_id", "").split("") - if sid.strip() - ] - - # Handle images - if "image" in entity_type: - image_chunks = [ - data - for sid in source_ids - if "image" in sid.lower() and (data := chunk_storage.get_by_id(sid)) - ] - if image_chunks: - # The generator expects a dictionary with an 'img_path' key, not a list of captions. - # We'll use the first image chunk found for this node. 
- node_data["images"] = image_chunks[0] - logger.debug("Attached image data to node %s", node_id) diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py new file mode 100644 index 0000000000000000000000000000000000000000..2289fec67e12f8401d1a10cebc3247f58c9d9ffd --- /dev/null +++ b/graphgen/operators/partition/partition_service.py @@ -0,0 +1,163 @@ +import json +import os +from typing import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseOperator, BaseTokenizer +from graphgen.common import init_storage +from graphgen.models import ( + AnchorBFSPartitioner, + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, + Tokenizer, +) +from graphgen.utils import logger + + +class PartitionService(BaseOperator): + def __init__( + self, + working_dir: str = "cache", + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + **partition_kwargs, + ): + super().__init__(working_dir=working_dir, op_name="partition_service") + self.kg_instance: BaseGraphStorage = init_storage( + backend=graph_backend, + working_dir=working_dir, + namespace="graph", + ) + self.chunk_storage: BaseKVStorage = init_storage( + backend=kv_backend, + working_dir=working_dir, + namespace="chunk", + ) + tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") + self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) + self.partition_kwargs = partition_kwargs + + def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + _ = batch.to_dict(orient="records") + self.kg_instance.reload() + self.chunk_storage.reload() + + yield from self.partition() + + def partition(self) -> Iterable[pd.DataFrame]: + method = self.partition_kwargs["method"] + method_params = self.partition_kwargs["method_params"] + if method == "bfs": + logger.info("Partitioning knowledge graph using BFS method.") + partitioner = BFSPartitioner() + elif method == "dfs": + logger.info("Partitioning knowledge graph using DFS method.") + partitioner = DFSPartitioner() + elif method == "ece": + logger.info("Partitioning knowledge graph using ECE method.") + # TODO: before ECE partitioning, we need to: + # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random + # 2. 
pre-tokenize nodes and edges to get the token length + self._pre_tokenize() + partitioner = ECEPartitioner() + elif method == "leiden": + logger.info("Partitioning knowledge graph using Leiden method.") + partitioner = LeidenPartitioner() + elif method == "anchor_bfs": + logger.info("Partitioning knowledge graph using Anchor BFS method.") + partitioner = AnchorBFSPartitioner( + anchor_type=method_params.get("anchor_type"), + anchor_ids=set(method_params.get("anchor_ids", [])) + if method_params.get("anchor_ids") + else None, + ) + else: + raise ValueError(f"Unsupported partition method: {method}") + + communities = partitioner.partition(g=self.kg_instance, **method_params) + + for community in communities: + batch = partitioner.community2batch(community, g=self.kg_instance) + batch = self._attach_additional_data_to_node(batch) + + yield pd.DataFrame( + { + "nodes": [batch[0]], + "edges": [batch[1]], + } + ) + + def _pre_tokenize(self) -> None: + """Pre-tokenize all nodes and edges to add token length information.""" + logger.info("Starting pre-tokenization of nodes and edges...") + + nodes = self.kg_instance.get_all_nodes() + edges = self.kg_instance.get_all_edges() + + # Process nodes + for node_id, node_data in nodes: + if "length" not in node_data: + try: + description = node_data.get("description", "") + tokens = self.tokenizer_instance.encode(description) + node_data["length"] = len(tokens) + self.kg_instance.update_node(node_id, node_data) + except Exception as e: + logger.warning("Failed to tokenize node %s: %s", node_id, e) + node_data["length"] = 0 + + # Process edges + for u, v, edge_data in edges: + if "length" not in edge_data: + try: + description = edge_data.get("description", "") + tokens = self.tokenizer_instance.encode(description) + edge_data["length"] = len(tokens) + self.kg_instance.update_edge(u, v, edge_data) + except Exception as e: + logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e) + edge_data["length"] = 0 + + # Persist changes + self.kg_instance.index_done_callback() + logger.info("Pre-tokenization completed.") + + def _attach_additional_data_to_node(self, batch: tuple) -> tuple: + """ + Attach additional data from chunk_storage to nodes in the batch. + :param batch: tuple of (nodes_data, edges_data) + :return: updated batch with additional data attached to nodes + """ + nodes_data, edges_data = batch + + for node_id, node_data in nodes_data: + entity_type = (node_data.get("entity_type") or "").lower() + if not entity_type: + continue + + source_ids = [ + sid.strip() + for sid in node_data.get("source_id", "").split("") + if sid.strip() + ] + + # Handle images + if "image" in entity_type: + image_chunks = [ + data + for sid in source_ids + if "image" in sid.lower() + and (data := self.chunk_storage.get_by_id(sid)) + ] + if image_chunks: + # The generator expects a dictionary with an 'img_path' key, not a list of captions. + # We'll use the first image chunk found for this node. 
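+                    # Assumption: the chunk's "content" field holds a JSON-encoded dict describing the image (e.g. with an 'img_path' key).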
+ node_data["image_data"] = json.loads(image_chunks[0]["content"]) + logger.debug("Attached image data to node %s", node_id) + + return nodes_data, edges_data diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py deleted file mode 100644 index 83e99060b5beed2b81ffa2806b82de9648394876..0000000000000000000000000000000000000000 --- a/graphgen/operators/partition/pre_tokenize.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio -from typing import List, Tuple - -import gradio as gr - -from graphgen.bases import BaseGraphStorage, BaseTokenizer -from graphgen.utils import run_concurrent - - -async def pre_tokenize( - graph_storage: BaseGraphStorage, - tokenizer: BaseTokenizer, - edges: List[Tuple], - nodes: List[Tuple], - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> Tuple[List, List]: - """为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。""" - sem = asyncio.Semaphore(max_concurrent) - - async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: - async with sem: - data = obj[1] if is_node else obj[2] - if "length" not in data: - loop = asyncio.get_event_loop() - data["length"] = len( - await loop.run_in_executor( - None, tokenizer.encode, data["description"] - ) - ) - if is_node: - graph_storage.update_node(obj[0], obj[1]) - else: - graph_storage.update_edge(obj[0], obj[1], obj[2]) - return obj - - new_edges, new_nodes = await asyncio.gather( - run_concurrent( - lambda e: _patch_and_write(e, is_node=False), - edges, - desc="Pre-tokenizing edges", - unit="edge", - progress_bar=progress_bar, - ), - run_concurrent( - lambda n: _patch_and_write(n, is_node=True), - nodes, - desc="Pre-tokenizing nodes", - unit="node", - progress_bar=progress_bar, - ), - ) - - graph_storage.index_done_callback() - return new_edges, new_nodes diff --git a/graphgen/operators/quiz/__init__.py b/graphgen/operators/quiz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a931f4b9dd95e5c86879116de8663958f5064ef --- /dev/null +++ b/graphgen/operators/quiz/__init__.py @@ -0,0 +1 @@ +from .quiz_service import QuizService diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py new file mode 100644 index 0000000000000000000000000000000000000000..a6aeb7be5480fd835e0870c1173693625a2a5da6 --- /dev/null +++ b/graphgen/operators/quiz/quiz_service.py @@ -0,0 +1,114 @@ +from collections.abc import Iterable + +import pandas as pd + +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm, init_storage +from graphgen.models import QuizGenerator +from graphgen.utils import compute_dict_hash, logger, run_concurrent + + +class QuizService(BaseOperator): + def __init__( + self, + working_dir: str = "cache", + graph_backend: str = "kuzu", + kv_backend: str = "rocksdb", + quiz_samples: int = 1, + concurrency_limit: int = 200, + ): + super().__init__(working_dir=working_dir, op_name="quiz_service") + self.quiz_samples = quiz_samples + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.graph_storage: BaseGraphStorage = init_storage( + backend=graph_backend, working_dir=working_dir, namespace="graph" + ) + # { _quiz_id: { "description": str, "quizzes": List[Tuple[str, str]] } } + self.quiz_storage: BaseKVStorage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="quiz" + ) + self.generator = QuizGenerator(self.llm_client) + self.concurrency_limit = concurrency_limit + + def process(self, batch: 
pd.DataFrame) -> Iterable[pd.DataFrame]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + _ = batch.to_dict(orient="records") + self.graph_storage.reload() + yield from self.quiz() + + async def _process_single_quiz(self, item: tuple) -> dict | None: + # if quiz in quiz_storage exists already, directly get it + index, desc = item + _quiz_id = compute_dict_hash({"index": index, "description": desc}) + if self.quiz_storage.get_by_id(_quiz_id): + return None + + tasks = [] + for i in range(self.quiz_samples): + if i > 0: + tasks.append((desc, "TEMPLATE", "yes")) + tasks.append((desc, "ANTI_TEMPLATE", "no")) + try: + quizzes = [] + for d, template_type, gt in tasks: + prompt = self.generator.build_prompt_for_description(d, template_type) + new_description = await self.llm_client.generate_answer( + prompt, temperature=1 + ) + rephrased_text = self.generator.parse_rephrased_text(new_description) + quizzes.append((rephrased_text, gt)) + return { + "_quiz_id": _quiz_id, + "description": desc, + "index": index, + "quizzes": quizzes, + } + except Exception as e: + logger.error("Error when quizzing description %s: %s", item, e) + return None + + def quiz(self) -> Iterable[pd.DataFrame]: + """ + Get all nodes and edges and quiz their descriptions using QuizGenerator. + """ + edges = self.graph_storage.get_all_edges() + nodes = self.graph_storage.get_all_nodes() + + items = [] + + for edge in edges: + edge_data = edge[2] + desc = edge_data["description"] + items.append(((edge[0], edge[1]), desc)) + + for node in nodes: + node_data = node[1] + desc = node_data["description"] + items.append((node[0], desc)) + + logger.info("Total descriptions to quiz: %d", len(items)) + + for i in range(0, len(items), self.concurrency_limit): + batch_items = items[i : i + self.concurrency_limit] + batch_results = run_concurrent( + self._process_single_quiz, + batch_items, + desc=f"Quizzing descriptions ({i} / {i + len(batch_items)})", + unit="description", + ) + + final_results = [] + for new_result in batch_results: + if new_result: + self.quiz_storage.upsert( + { + new_result["_quiz_id"]: { + "description": new_result["description"], + "quizzes": new_result["quizzes"], + } + } + ) + final_results.append(new_result) + self.quiz_storage.index_done_callback() + yield pd.DataFrame(final_results) diff --git a/graphgen/operators/quiz_and_judge/__init__.py b/graphgen/operators/quiz_and_judge/__init__.py deleted file mode 100644 index cb73251a9c7c91441f08d53634a901f0ae623949..0000000000000000000000000000000000000000 --- a/graphgen/operators/quiz_and_judge/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .judge import judge_statement -from .quiz import quiz diff --git a/graphgen/operators/quiz_and_judge/judge.py b/graphgen/operators/quiz_and_judge/judge.py deleted file mode 100644 index b5e35eb9bf032c7dfa21da10b7525b5061a08dd4..0000000000000000000000000000000000000000 --- a/graphgen/operators/quiz_and_judge/judge.py +++ /dev/null @@ -1,139 +0,0 @@ -import math - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from graphgen.models import JsonKVStorage, NetworkXStorage -from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT -from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy - - -async def judge_statement( # pylint: disable=too-many-statements - trainee_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - re_judge: bool = False, - progress_bar: gr.Progress = None, -) -> NetworkXStorage: - """ - 
Get all edges and nodes and judge them - - :param trainee_llm_client: judge the statements to get comprehension loss - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param re_judge: re-judge the relations - :param progress_bar - :return: - """ - - async def _judge_single_relation( - edge: tuple, - ): - source_id = edge[0] - target_id = edge[1] - edge_data = edge[2] - - if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None: - logger.debug( - "Edge %s -> %s already judged, loss: %s, skip", - source_id, - target_id, - edge_data["loss"], - ) - return source_id, target_id, edge_data - - description = edge_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug( - "Edge %s -> %s description: %s loss: %s", - source_id, - target_id, - description, - loss, - ) - - edge_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error( - "Error in judging relation %s -> %s: %s", source_id, target_id, e - ) - logger.info("Use default loss 0.1") - edge_data["loss"] = -math.log(0.1) - - graph_storage.update_edge(source_id, target_id, edge_data) - return source_id, target_id, edge_data - - edges = graph_storage.get_all_edges() - - await run_concurrent( - _judge_single_relation, - edges, - desc="Judging relations", - unit="relation", - progress_bar=progress_bar, - ) - - async def _judge_single_entity( - node: tuple, - ): - node_id = node[0] - node_data = node[1] - - if (not re_judge) and "loss" in node_data and node_data["loss"] is not None: - logger.debug( - "Node %s already judged, loss: %s, skip", node_id, node_data["loss"] - ) - return node_id, node_data - - description = node_data["description"] - - try: - descriptions = rephrase_storage.get_by_id(description) - assert descriptions is not None - - judgements = [] - gts = [gt for _, gt in descriptions] - for description, gt in descriptions: - judgement = await trainee_llm_client.generate_topk_per_token( - STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description) - ) - judgements.append(judgement[0].top_candidates) - - loss = yes_no_loss_entropy(judgements, gts) - - logger.debug("Node %s description: %s loss: %s", node_id, description, loss) - - node_data["loss"] = loss - except Exception as e: # pylint: disable=broad-except - logger.error("Error in judging entity %s: %s", node_id, e) - logger.error("Use default loss 0.1") - node_data["loss"] = -math.log(0.1) - - graph_storage.update_node(node_id, node_data) - return node_id, node_data - - nodes = graph_storage.get_all_nodes() - - await run_concurrent( - _judge_single_entity, - nodes, - desc="Judging entities", - unit="entity", - progress_bar=progress_bar, - ) - - return graph_storage diff --git a/graphgen/operators/quiz_and_judge/quiz.py b/graphgen/operators/quiz_and_judge/quiz.py deleted file mode 100644 index 9aadb34b8e5f82749165aa6b5ff9dd7ef7f6fd20..0000000000000000000000000000000000000000 --- a/graphgen/operators/quiz_and_judge/quiz.py +++ /dev/null @@ -1,93 +0,0 @@ -from collections import defaultdict - -import gradio as gr - -from graphgen.bases import BaseLLMWrapper -from 
graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator -from graphgen.utils import logger, run_concurrent - - -async def quiz( - synth_llm_client: BaseLLMWrapper, - graph_storage: NetworkXStorage, - rephrase_storage: JsonKVStorage, - max_samples: int = 1, - progress_bar: gr.Progress = None, -) -> JsonKVStorage: - """ - Get all edges and quiz them using QuizGenerator. - - :param synth_llm_client: generate statements - :param graph_storage: graph storage instance - :param rephrase_storage: rephrase storage instance - :param max_samples: max samples for each edge - :param progress_bar - :return: - """ - - generator = QuizGenerator(synth_llm_client) - - async def _process_single_quiz(item: tuple[str, str, str]): - description, template_type, gt = item - try: - # if rephrase_storage exists already, directly get it - descriptions = rephrase_storage.get_by_id(description) - if descriptions: - return None - - prompt = generator.build_prompt_for_description(description, template_type) - new_description = await synth_llm_client.generate_answer( - prompt, temperature=1 - ) - rephrased_text = generator.parse_rephrased_text(new_description) - return {description: [(rephrased_text, gt)]} - - except Exception as e: # pylint: disable=broad-except - logger.error("Error when quizzing description %s: %s", description, e) - return None - - edges = graph_storage.get_all_edges() - nodes = graph_storage.get_all_nodes() - - results = defaultdict(list) - items = [] - for edge in edges: - edge_data = edge[2] - description = edge_data["description"] - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - for node in nodes: - node_data = node[1] - description = node_data["description"] - - results[description] = [(description, "yes")] - - for i in range(max_samples): - if i > 0: - items.append((description, "TEMPLATE", "yes")) - items.append((description, "ANTI_TEMPLATE", "no")) - - quiz_results = await run_concurrent( - _process_single_quiz, - items, - desc="Quizzing descriptions", - unit="description", - progress_bar=progress_bar, - ) - - for new_result in quiz_results: - if new_result: - for key, value in new_result.items(): - results[key].extend(value) - - for key, value in results.items(): - results[key] = list(set(value)) - rephrase_storage.upsert({key: results[key]}) - - return rephrase_storage diff --git a/graphgen/operators/read/__init__.py b/graphgen/operators/read/__init__.py index 075ae9381c390177965bbf4a5c55b81447f1fa84..cda44587a72d79722261bed4203649ce29dd50a5 100644 --- a/graphgen/operators/read/__init__.py +++ b/graphgen/operators/read/__init__.py @@ -1 +1 @@ -from .read_files import read_files +from .read import read diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py index 73b477c307a2afb0368605166044b79f05149785..84219139d4f88961c21329d33bfe576038f0f018 100644 --- a/graphgen/operators/read/parallel_file_scanner.py +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -5,14 +5,13 @@ from pathlib import Path from typing import Any, Dict, List, Set, Union from graphgen.models import RocksDBCache -from graphgen.utils import logger class ParallelFileScanner: def __init__( self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4 ): - self.cache = RocksDBCache(os.path.join(cache_dir, "file_paths_cache")) + self.cache = RocksDBCache(os.path.join(cache_dir, 
"input_paths.db")) self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None self.rescan = rescan self.max_workers = max_workers @@ -32,15 +31,12 @@ class ParallelFileScanner: self._scan_files, Path(p).resolve(), recursive, set() ) future_to_path[future] = p - else: - logger.warning("[READ] Path does not exist: %s", p) for future in as_completed(future_to_path): path = future_to_path[future] try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning path %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -56,17 +52,14 @@ class ParallelFileScanner: # Avoid cycles due to symlinks if path_str in visited: - logger.warning("[READ] Skipping already visited path: %s", path_str) return self._empty_result(path_str) # cache check cache_key = f"scan::{path_str}::recursive::{recursive}" cached = self.cache.get(cache_key) if cached and not self.rescan: - logger.info("[READ] Using cached scan result for path: %s", path_str) return cached["data"] - logger.info("[READ] Scanning path: %s", path_str) files, dirs = [], [] stats = {"total_size": 0, "file_count": 0, "dir_count": 0, "errors": 0} @@ -108,7 +101,6 @@ class ParallelFileScanner: stats["errors"] += 1 except (PermissionError, FileNotFoundError, OSError) as e: - logger.error("[READ] Failed to scan path %s: %s", path_str, e) return {"error": str(e), "files": [], "dirs": [], "stats": stats} if recursive: @@ -171,7 +163,6 @@ class ParallelFileScanner: try: results[path] = future.result() except Exception as e: - logger.error("[READ] Error scanning subdirectory %s: %s", path, e) results[path] = { "error": str(e), "files": [], @@ -183,18 +174,14 @@ class ParallelFileScanner: def _cache_result(self, key: str, result: Dict, path: Path): """Cache the scan result""" - try: - self.cache.set( - key, - { - "data": result, - "dir_mtime": path.stat().st_mtime, - "cached_at": time.time(), - }, - ) - logger.info("[READ] Cached scan result for path: %s", path) - except OSError as e: - logger.error("[READ] Failed to cache scan result for path %s: %s", path, e) + self.cache.set( + key, + { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + ) def _is_allowed_file(self, path: Path) -> bool: """Check if the file has an allowed suffix""" @@ -209,7 +196,6 @@ class ParallelFileScanner: keys = [k for k in self.cache if k.startswith(f"scan::{path}")] for k in keys: self.cache.delete(k) - logger.info("[READ] Invalidated cache for path: %s", path) def close(self): self.cache.close() diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py new file mode 100644 index 0000000000000000000000000000000000000000..fbed377eb538f482e9756aea9ca980d4b4b5f0e6 --- /dev/null +++ b/graphgen/operators/read/read.py @@ -0,0 +1,128 @@ +from pathlib import Path +from typing import Any, List, Optional, Union + +import ray + +from graphgen.models import ( + CSVReader, + JSONReader, + ParquetReader, + PDFReader, + PickleReader, + RDFReader, + TXTReader, +) +from graphgen.utils import compute_mm_hash, logger + +from .parallel_file_scanner import ParallelFileScanner + +_MAPPING = { + "jsonl": JSONReader, + "json": JSONReader, + "txt": TXTReader, + "csv": CSVReader, + "md": TXTReader, + "pdf": PDFReader, + "parquet": ParquetReader, + "pickle": PickleReader, + "rdf": RDFReader, + "owl": RDFReader, + "ttl": RDFReader, +} + + +def _build_reader(suffix: str, cache_dir: str | None, **reader_kwargs): + """Factory function to build appropriate reader instance""" + suffix = 
suffix.lower() + reader_cls = _MAPPING.get(suffix) + if not reader_cls: + raise ValueError(f"Unsupported file suffix: {suffix}") + + # Special handling for PDFReader which needs output_dir + if suffix == "pdf": + if cache_dir is None: + raise ValueError("cache_dir must be provided for PDFReader") + return reader_cls(output_dir=cache_dir, **reader_kwargs) + + return reader_cls(**reader_kwargs) + + +def read( + input_path: Union[str, List[str]], + allowed_suffix: Optional[List[str]] = None, + cache_dir: Optional[str] = "cache", + parallelism: int = 4, + recursive: bool = True, + **reader_kwargs: Any, +) -> ray.data.Dataset: + """ + Unified entry point to read files of multiple types using Ray Data. + + :param input_path: File or directory path(s) to read from + :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) + :param cache_dir: Directory to cache intermediate files (PDF processing) + :param parallelism: Number of parallel workers + :param recursive: Whether to scan directories recursively + :param reader_kwargs: Additional kwargs passed to readers + :return: Ray Dataset containing all documents + """ + try: + # 1. Scan all paths to discover files + logger.info("[READ] Scanning paths: %s", input_path) + scanner = ParallelFileScanner( + cache_dir=cache_dir, + allowed_suffix=allowed_suffix, + rescan=False, + max_workers=parallelism if parallelism > 0 else 1, + ) + + all_files = [] + scan_results = scanner.scan(input_path, recursive=recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + + if not all_files: + raise ValueError("No files found to read.") + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + for suffix, file_paths in files_by_suffix.items(): + reader = _build_reader(suffix, cache_dir, **reader_kwargs) + ds = reader.read(file_paths) + read_tasks.append(ds) + + # 4. 
Combine all datasets + if not read_tasks: + raise ValueError("No datasets created from the provided files.") + + if len(read_tasks) == 1: + combined_ds = read_tasks[0] + else: + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + combined_ds = combined_ds.map( + lambda record: { + **record, + "_doc_id": compute_mm_hash(record, prefix="doc-"), + } + ) + + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds + + except Exception as e: + logger.error("[READ] Failed to read files from %s: %s", input_path, e) + raise diff --git a/graphgen/operators/read/read_files.py b/graphgen/operators/read/read_files.py deleted file mode 100644 index d9e7f673790bf3f8f9a6d01ed9d30d8d500ed895..0000000000000000000000000000000000000000 --- a/graphgen/operators/read/read_files.py +++ /dev/null @@ -1,99 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional - -from graphgen.models import ( - CSVReader, - JSONLReader, - JSONReader, - ParquetReader, - PDFReader, - PickleReader, - RDFReader, - TXTReader, -) -from graphgen.utils import logger - -from .parallel_file_scanner import ParallelFileScanner - -_MAPPING = { - "jsonl": JSONLReader, - "json": JSONReader, - "txt": TXTReader, - "csv": CSVReader, - "md": TXTReader, - "pdf": PDFReader, - "parquet": ParquetReader, - "pickle": PickleReader, - "rdf": RDFReader, - "owl": RDFReader, - "ttl": RDFReader, -} - - -def _build_reader(suffix: str, cache_dir: str | None): - suffix = suffix.lower() - if suffix == "pdf" and cache_dir is not None: - return _MAPPING[suffix](output_dir=cache_dir) - return _MAPPING[suffix]() - - -def read_files( - input_file: str, - allowed_suffix: Optional[List[str]] = None, - cache_dir: Optional[str] = None, - max_workers: int = 4, - rescan: bool = False, -) -> Iterator[Dict[str, Any]]: - """ - Read files from a path using parallel scanning and appropriate readers. - - Args: - input_file: Path to a file or directory - allowed_suffix: List of file suffixes to read. 
If None, uses all supported types - cache_dir: Directory for caching PDF extraction and scan results - max_workers: Number of workers for parallel scanning - rescan: Whether to force rescan even if cached results exist - """ - - path = Path(input_file).expanduser() - if not path.exists(): - raise FileNotFoundError(f"input_path not found: {input_file}") - - if allowed_suffix is None: - support_suffix = set(_MAPPING.keys()) - else: - support_suffix = {s.lower().lstrip(".") for s in allowed_suffix} - - with ParallelFileScanner( - cache_dir=cache_dir or "cache", - allowed_suffix=support_suffix, - rescan=rescan, - max_workers=max_workers, - ) as scanner: - scan_results = scanner.scan(str(path), recursive=True) - - # Extract files from scan results - files_to_read = [] - for path_result in scan_results.values(): - if "error" in path_result: - logger.warning("Error scanning %s: %s", path_result.path, path_result.error) - continue - files_to_read.extend(path_result.get("files", [])) - - logger.info( - "Found %d eligible file(s) under folder %s (allowed_suffix=%s)", - len(files_to_read), - input_file, - support_suffix, - ) - - for file_info in files_to_read: - try: - file_path = file_info["path"] - suffix = Path(file_path).suffix.lstrip(".").lower() - reader = _build_reader(suffix, cache_dir) - - yield from reader.read(file_path) - - except Exception as e: # pylint: disable=broad-except - logger.exception("Error reading %s: %s", file_info.get("path"), e) diff --git a/graphgen/operators/split/__init__.py b/graphgen/operators/split/__init__.py deleted file mode 100644 index 2afc738dee1ad98bb72b6885dccab32c0628499b..0000000000000000000000000000000000000000 --- a/graphgen/operators/split/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .split_chunks import chunk_documents diff --git a/graphgen/operators/split/split_chunks.py b/graphgen/operators/split/split_chunks.py deleted file mode 100644 index 3f728e00d983738958bf84a9630819b8edf7d9e9..0000000000000000000000000000000000000000 --- a/graphgen/operators/split/split_chunks.py +++ /dev/null @@ -1,84 +0,0 @@ -from functools import lru_cache -from typing import Union - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.models import ( - ChineseRecursiveTextSplitter, - RecursiveCharacterSplitter, - Tokenizer, -) -from graphgen.utils import compute_content_hash, detect_main_language - -_MAPPING = { - "en": RecursiveCharacterSplitter, - "zh": ChineseRecursiveTextSplitter, -} - -SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] - - -@lru_cache(maxsize=None) -def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: - cls = _MAPPING[language] - kwargs = dict(frozen_kwargs) - return cls(**kwargs) - - -def split_chunks(text: str, language: str = "en", **kwargs) -> list: - if language not in _MAPPING: - raise ValueError( - f"Unsupported language: {language}. 
" - f"Supported languages are: {list(_MAPPING.keys())}" - ) - frozen_kwargs = frozenset( - (k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items() - ) - splitter = _get_splitter(language, frozen_kwargs) - return splitter.split_text(text) - - -async def chunk_documents( - new_docs: dict, - tokenizer_instance: Tokenizer = None, - progress_bar=None, - **kwargs, -) -> dict: - inserting_chunks = {} - cur_index = 1 - doc_number = len(new_docs) - async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): - doc_type = doc.get("type") - if doc_type == "text": - doc_language = detect_main_language(doc["content"]) - - text_chunks = split_chunks( - doc["content"], - language=doc_language, - **kwargs, - ) - - chunks = { - compute_content_hash(txt, prefix="chunk-"): { - "content": txt, - "type": "text", - "_full_docs_id": doc_key, - "length": len(tokenizer_instance.encode(txt)) - if tokenizer_instance - else len(txt), - "language": doc_language, - } - for txt in text_chunks - } - else: - chunks = {doc_key.replace("doc-", f"{doc_type}-"): {**doc}} - - inserting_chunks.update(chunks) - - if progress_bar is not None: - progress_bar(cur_index / doc_number, f"Chunking {doc_key}") - cur_index += 1 - - return inserting_chunks diff --git a/graphgen/operators/storage.py b/graphgen/operators/storage.py deleted file mode 100644 index ea5488ac3427722fe8ddc75ef5d12c80c4b6ea09..0000000000000000000000000000000000000000 --- a/graphgen/operators/storage.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -from typing import Any - -import ray - -from graphgen.models import JsonKVStorage, JsonListStorage, NetworkXStorage - - -@ray.remote -class StorageManager: - """ - Centralized storage for all operators - - Example Usage: - ---------- - # init - storage_manager = StorageManager.remote(working_dir="/path/to/dir", unique_id=123) - - # visit storage in tasks - @ray.remote - def some_task(storage_manager): - full_docs_storage = ray.get(storage_manager.get_storage.remote("full_docs")) - - # visit storage in other actors - @ray.remote - class SomeOperator: - def __init__(self, storage_manager): - self.storage_manager = storage_manager - def some_method(self): - full_docs_storage = ray.get(self.storage_manager.get_storage.remote("full_docs")) - """ - - def __init__(self, working_dir: str, unique_id: int): - self.working_dir = working_dir - self.unique_id = unique_id - - # Initialize all storage backends - self.storages = { - "full_docs": JsonKVStorage(working_dir, namespace="full_docs"), - "chunks": JsonKVStorage(working_dir, namespace="chunks"), - "graph": NetworkXStorage(working_dir, namespace="graph"), - "rephrase": JsonKVStorage(working_dir, namespace="rephrase"), - "partition": JsonListStorage(working_dir, namespace="partition"), - "search": JsonKVStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="search", - ), - "extraction": JsonKVStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="extraction", - ), - "qa": JsonListStorage( - os.path.join(working_dir, "data", "graphgen", f"{unique_id}"), - namespace="qa", - ), - } - - def get_storage(self, name: str) -> Any: - return self.storages.get(name) diff --git a/graphgen/run.py b/graphgen/run.py index c300a6aa56631b48bb510fbcdc8d540632fea3f4..5ae34d7eac0b60fac490f9de0ef9a8a4a91cc3e8 100644 --- a/graphgen/run.py +++ b/graphgen/run.py @@ -1,14 +1,18 @@ import argparse import os import time -from importlib.resources import files +from importlib import 
resources +from typing import Any, Dict +import ray import yaml from dotenv import load_dotenv +from ray.data.block import Block +from ray.data.datasource.filename_provider import FilenameProvider -from graphgen.engine import Context, Engine, collect_ops -from graphgen.graphgen import GraphGen -from graphgen.utils import logger, set_logger +from graphgen.engine import Engine +from graphgen.operators import operators +from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) @@ -28,12 +32,38 @@ def save_config(config_path, global_config): ) +class NodeFilenameProvider(FilenameProvider): + def __init__(self, node_id: str): + self.node_id = node_id + + def get_filename_for_block( + self, block: Block, write_uuid: str, task_index: int, block_index: int + ) -> str: + # format: {node_id}_{write_uuid}_{task_index:06}_{block_index:06}.json + return f"{self.node_id}_{write_uuid}_{task_index:06d}_{block_index:06d}.jsonl" + + def get_filename_for_row( + self, + row: Dict[str, Any], + write_uuid: str, + task_index: int, + block_index: int, + row_index: int, + ) -> str: + raise NotImplementedError( + f"Row-based filenames are not supported by write_json. " + f"Node: {self.node_id}, write_uuid: {write_uuid}" + ) + + def main(): parser = argparse.ArgumentParser() parser.add_argument( "--config_file", help="Config parameters for GraphGen.", - default=files("graphgen").joinpath("configs", "aggregated_config.yaml"), + default=resources.files("graphgen") + .joinpath("configs") + .joinpath("aggregated_config.yaml"), type=str, ) parser.add_argument( @@ -52,28 +82,38 @@ def main(): config = yaml.load(f, Loader=yaml.FullLoader) unique_id = int(time.time()) - - output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}") + output_path = os.path.join(working_dir, "output", f"{unique_id}") set_working_dir(output_path) - - set_logger( - os.path.join(output_path, f"{unique_id}.log"), + log_path = os.path.join(working_dir, "logs", "Driver.log") + driver_logger = set_logger( + log_path, + name="GraphGen", if_stream=True, ) + CURRENT_LOGGER_VAR.set(driver_logger) logger.info( "GraphGen with unique ID %s logging to %s", unique_id, - os.path.join(working_dir, f"{unique_id}.log"), + log_path, ) - graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir) - - # share context between different steps - ctx = Context(config=config, graph_gen=graph_gen) - ops = collect_ops(config, graph_gen) - - # run operations - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) + engine = Engine(config, operators) + ds = ray.data.from_items([]) + results = engine.execute(ds) + + for node_id, dataset in results.items(): + output_path = os.path.join(output_path, f"{node_id}") + os.makedirs(output_path, exist_ok=True) + dataset.write_json( + output_path, + filename_provider=NodeFilenameProvider(node_id), + pandas_json_args_fn=lambda: { + "force_ascii": False, + "orient": "records", + "lines": True, + }, + ) + logger.info("Node %s results saved to %s", node_id, output_path) save_config(os.path.join(output_path, "config.yaml"), config) logger.info("GraphGen completed successfully. 
Data saved to %s", output_path) diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index d3e6df7b6f523e7e4504d8c533a7cba669688c28..ec118816674640ccfe13467ff5b1707e907f6569 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -16,7 +16,7 @@ from .hash import ( compute_mm_hash, ) from .help_nltk import NLTKHelper -from .log import logger, parse_log, set_logger +from .log import CURRENT_LOGGER_VAR, logger, set_logger from .loop import create_event_loop from .run_concurrent import run_concurrent from .wrap import async_to_sync_method diff --git a/graphgen/utils/log.py b/graphgen/utils/log.py index 102b7b23d8d7b793301872881329086b0b9e4d1d..e29e994e86b8589320aeccdabdb8f3d19296a74e 100644 --- a/graphgen/utils/log.py +++ b/graphgen/utils/log.py @@ -1,13 +1,15 @@ +import contextvars import logging +import os from logging.handlers import RotatingFileHandler +from typing import Any from rich.logging import RichHandler -logger = logging.getLogger("graphgen") - def set_logger( log_file: str, + name: str, file_level: int = logging.DEBUG, console_level: int = logging.INFO, *, @@ -17,26 +19,27 @@ def set_logger( force: bool = False, ): - if logger.hasHandlers() and not force: - return + current_logger = logging.getLogger(name) + if current_logger.hasHandlers() and not force: + return current_logger if force: - logger.handlers.clear() + current_logger.handlers.clear() - logger.setLevel( + current_logger.setLevel( min(file_level, console_level) ) # Set to the lowest level to capture all logs - logger.propagate = False + current_logger.propagate = False - if logger.handlers: - logger.handlers.clear() + if log_file: + os.makedirs(os.path.dirname(log_file), exist_ok=True) if if_stream: console = RichHandler( level=console_level, show_path=False, rich_tracebacks=True ) console.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(console) + current_logger.addHandler(console) file_handler = RotatingFileHandler( log_file, @@ -51,10 +54,48 @@ def set_logger( datefmt="%y-%m-%d %H:%M:%S", ) ) - logger.addHandler(file_handler) + current_logger.addHandler(file_handler) + return current_logger + + +CURRENT_LOGGER_VAR = contextvars.ContextVar("current_logger") + + +def get_current_logger() -> logging.Logger: + current_logger = CURRENT_LOGGER_VAR.get() + if not current_logger: + raise RuntimeError("No logger is set in the current context.") + return current_logger + + +class ContextAwareLogger: + @staticmethod + def _get_logger() -> logging.Logger: + return get_current_logger() + + def debug(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().debug(msg, *args, **kwargs) + + def info(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().info(msg, *args, **kwargs) + + def warning(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().warning(msg, *args, **kwargs) + + def error(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().error(msg, *args, **kwargs) + + def exception(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().exception(msg, *args, **kwargs) + + def critical(self, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().critical(msg, *args, **kwargs) + + def log(self, level: int, msg: object, *args: Any, **kwargs: Any) -> None: + self._get_logger().log(level, msg, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_logger(), name) -def parse_log(log_file: str): - with open(log_file, "r", 
encoding="utf-8") as f: - lines = f.readlines() - return lines +logger = ContextAwareLogger() diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py index ac63f87bd4ab1a897c61dd6d5ddd3c43b3ffb830..d1a9b0e261cccedd76ba130062e2d665b1f57622 100644 --- a/graphgen/utils/run_concurrent.py +++ b/graphgen/utils/run_concurrent.py @@ -1,55 +1,44 @@ import asyncio -from typing import Awaitable, Callable, List, Optional, TypeVar +from typing import Awaitable, Callable, List, TypeVar -import gradio as gr from tqdm.asyncio import tqdm as tqdm_async from graphgen.utils.log import logger +from .loop import create_event_loop + T = TypeVar("T") R = TypeVar("R") -async def run_concurrent( +def run_concurrent( coro_fn: Callable[[T], Awaitable[R]], items: List[T], *, desc: str = "processing", unit: str = "item", - progress_bar: Optional[gr.Progress] = None, ) -> List[R]: - tasks = [asyncio.create_task(coro_fn(it)) for it in items] - - completed_count = 0 - results = [] - - pbar = tqdm_async(total=len(items), desc=desc, unit=unit) - - if progress_bar is not None: - progress_bar(0.0, desc=f"{desc} (0/{len(items)})") - - for future in asyncio.as_completed(tasks): - try: - result = await future - results.append(result) - except Exception as e: # pylint: disable=broad-except - logger.exception("Task failed: %s", e) - # even if failed, record it to keep results consistent with tasks - results.append(e) - - completed_count += 1 - pbar.update(1) - - if progress_bar is not None: - progress = completed_count / len(items) - progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})") - - pbar.close() - - if progress_bar is not None: - progress_bar(1.0, desc=f"{desc} (completed)") - - # filter out exceptions - results = [res for res in results if not isinstance(res, Exception)] - - return results + async def _run_all(): + tasks = [asyncio.create_task(coro_fn(item)) for item in items] + + results = [] + pbar = tqdm_async(total=len(items), desc=desc, unit=unit) + + for future in asyncio.as_completed(tasks): + try: + result = await future + results.append(result) + except Exception as e: + logger.exception("Task failed: %s", e) + results.append(e) + + pbar.update(1) + + pbar.close() + return [res for res in results if not isinstance(res, Exception)] + + loop = create_event_loop() + try: + return loop.run_until_complete(_run_all()) + finally: + loop.close() diff --git a/requirements.txt b/requirements.txt index 85fc43e3adda2529b435be7eb064abaf3dd8a4a5..44079ab5f04eca1a34d8adac53c8c29bf2df1114 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,6 +21,8 @@ fastapi trafilatura aiohttp socksio +pydantic +ray==2.52.1 leidenalg igraph diff --git a/webui/app.py b/webui/app.py index dfd0edda5b989cea18c100802e8e7a4e0b70df8e..98b02601d90512c6a44703f8ff8e70d1963c4c1e 100644 --- a/webui/app.py +++ b/webui/app.py @@ -5,14 +5,12 @@ import tempfile from importlib.resources import files import gradio as gr -import pandas as pd +import ray from dotenv import load_dotenv -from graphgen.engine import Context, Engine, collect_ops -from graphgen.graphgen import GraphGen -from graphgen.models import OpenAIClient, Tokenizer -from graphgen.models.llm.limitter import RPM, TPM -from graphgen.utils import set_logger +from graphgen.engine import Engine +from graphgen.operators import operators +from graphgen.utils import CURRENT_LOGGER_VAR, set_logger from webui.base import WebuiParams from webui.i18n import Translate from webui.i18n import gettext as _ @@ -22,7 +20,6 @@ from webui.utils import 
cleanup_workspace, count_tokens, preview_file, setup_wor root_dir = files("webui").parent sys.path.append(root_dir) - load_dotenv() css = """ @@ -34,131 +31,136 @@ css = """ """ -def init_graph_gen(config: dict, env: dict) -> GraphGen: - # Set up working directory - log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) - set_logger(log_file, if_stream=True) - os.environ.update({k: str(v) for k, v in env.items()}) - - tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base")) - synthesizer_llm_client = OpenAIClient( - model=env.get("SYNTHESIZER_MODEL", ""), - base_url=env.get("SYNTHESIZER_BASE_URL", ""), - api_key=env.get("SYNTHESIZER_API_KEY", ""), - request_limit=True, - rpm=RPM(env.get("RPM", 1000)), - tpm=TPM(env.get("TPM", 50000)), - tokenizer=tokenizer_instance, - ) - trainee_llm_client = OpenAIClient( - model=env.get("TRAINEE_MODEL", ""), - base_url=env.get("TRAINEE_BASE_URL", ""), - api_key=env.get("TRAINEE_API_KEY", ""), - request_limit=True, - rpm=RPM(env.get("RPM", 1000)), - tpm=TPM(env.get("TPM", 50000)), - tokenizer=tokenizer_instance, - ) - - graph_gen = GraphGen( - working_dir=working_dir, - tokenizer_instance=tokenizer_instance, - synthesizer_llm_client=synthesizer_llm_client, - trainee_llm_client=trainee_llm_client, - ) - - return graph_gen - - -# pylint: disable=too-many-statements -def run_graphgen(params: WebuiParams, progress=gr.Progress()): - def sum_tokens(client): - return sum(u["total_tokens"] for u in client.token_usage) - +def _get_partition_params(params: WebuiParams): method = params.partition_method if method == "dfs": - partition_params = { + return { "max_units_per_community": params.dfs_max_units, } - elif method == "bfs": - partition_params = { + if method == "bfs": + return { "max_units_per_community": params.bfs_max_units, } - elif method == "leiden": - partition_params = { + if method == "leiden": + return { "max_size": params.leiden_max_size, "use_lcc": params.leiden_use_lcc, "random_seed": params.leiden_random_seed, } - else: # ece - partition_params = { - "max_units_per_community": params.ece_max_units, - "min_units_per_community": params.ece_min_units, - "max_tokens_per_community": params.ece_max_tokens, - "unit_sampling": params.ece_unit_sampling, - } + # ece + return { + "max_units_per_community": params.ece_max_units, + "min_units_per_community": params.ece_min_units, + "max_tokens_per_community": params.ece_max_tokens, + "unit_sampling": params.ece_unit_sampling, + } + + +# pylint: disable=too-many-statements +def run_graphgen(params: WebuiParams, progress=gr.Progress()): + # 1. Setup Workspace + log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) + driver_logger = set_logger(log_file, "GraphGeb", if_stream=True) + CURRENT_LOGGER_VAR.set(driver_logger) + + # 2. 
Setup Environment Variables for Ray Actors/LLM Init + # The refactored code relies on env vars in graphgen/common/init_llm.py + os.environ["SYNTHESIZER_BACKEND"] = "openai_api" # Assuming OpenAI compatible API + os.environ["SYNTHESIZER_BASE_URL"] = params.synthesizer_url + os.environ["SYNTHESIZER_API_KEY"] = params.api_key + os.environ["SYNTHESIZER_MODEL"] = params.synthesizer_model + os.environ["RPM"] = str(params.rpm) + os.environ["TPM"] = str(params.tpm) + os.environ["TOKENIZER_MODEL"] = params.tokenizer + + if params.if_trainee_model: + os.environ["TRAINEE_BACKEND"] = "openai_api" + os.environ["TRAINEE_BASE_URL"] = params.trainee_url + os.environ["TRAINEE_API_KEY"] = params.trainee_api_key + os.environ["TRAINEE_MODEL"] = params.trainee_model - pipeline = [ + # 3. Construct Pipeline Configuration (DAG) + nodes = [ { - "name": "read", - "op_key": "read", + "id": "read", + "op_name": "read", + "type": "source", + "dependencies": [], "params": { - "input_file": params.upload_file, + "input_path": [params.upload_file], }, }, { - "name": "chunk", - "deps": ["read"], - "op_key": "chunk", + "id": "chunk", + "op_name": "chunk", + "type": "map_batch", + "dependencies": ["read"], + "execution_params": {"replicas": 1}, "params": { "chunk_size": params.chunk_size, "chunk_overlap": params.chunk_overlap, }, }, { - "name": "build_kg", - "deps": ["chunk"], - "op_key": "build_kg", + "id": "build_kg", + "op_name": "build_kg", + "type": "map_batch", + "dependencies": ["chunk"], + "execution_params": {"replicas": 1, "batch_size": 128}, }, ] + last_node_id = "build_kg" + + # Optional: Quiz and Judge if params.if_trainee_model: - pipeline.append( - { - "name": "quiz_and_judge", - "deps": ["build_kg"], - "op_key": "quiz_and_judge", - "params": {"quiz_samples": params.quiz_samples, "re_judge": True}, - } - ) - pipeline.append( + nodes.append( { - "name": "partition", - "deps": ["quiz_and_judge"], - "op_key": "partition", + "id": "quiz", + "op_name": "quiz", + "type": "aggregate", # QuizService uses aggregate in config + "dependencies": ["build_kg"], + "execution_params": {"replicas": 1, "batch_size": 128}, "params": { - "method": params.partition_method, - "method_params": partition_params, + "quiz_samples": params.quiz_samples, + "concurrency_limit": 200, }, } ) - else: - pipeline.append( + + nodes.append( { - "name": "partition", - "deps": ["build_kg"], - "op_key": "partition", - "params": { - "method": params.partition_method, - "method_params": partition_params, - }, + "id": "judge", + "op_name": "judge", + "type": "map_batch", + "dependencies": ["quiz"], + "execution_params": {"replicas": 1, "batch_size": 128}, } ) - pipeline.append( + last_node_id = "judge" + + # Node: Partition + nodes.append( + { + "id": "partition", + "op_name": "partition", + "type": "aggregate", # PartitionService uses aggregate + "dependencies": [last_node_id], + "params": { + "method": params.partition_method, + "method_params": _get_partition_params(params), + }, + } + ) + + # Node: Generate + nodes.append( { - "name": "generate", - "deps": ["partition"], - "op_key": "generate", + "id": "generate", + "op_name": "generate", + "type": "map_batch", + "dependencies": ["partition"], + "execution_params": {"replicas": 1, "batch_size": 128}, "params": { "method": params.mode, "data_format": params.data_format, @@ -166,88 +168,50 @@ def run_graphgen(params: WebuiParams, progress=gr.Progress()): } ) - config = { - "if_trainee_model": params.if_trainee_model, - "read": {"input_file": params.upload_file}, - "pipeline": pipeline, - } + 
config = {"global_params": {"working_dir": working_dir}, "nodes": nodes} - env = { - "TOKENIZER_MODEL": params.tokenizer, - "SYNTHESIZER_BASE_URL": params.synthesizer_url, - "SYNTHESIZER_MODEL": params.synthesizer_model, - "TRAINEE_BASE_URL": params.trainee_url, - "TRAINEE_MODEL": params.trainee_model, - "SYNTHESIZER_API_KEY": params.api_key, - "TRAINEE_API_KEY": params.trainee_api_key, - "RPM": params.rpm, - "TPM": params.tpm, - } + try: + # 4. Initialize and Run Engine + # Initialize Ray if not already running (Engine handles this mostly, but good for safety) + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True, log_to_driver=True) - # Test API connection - test_api_connection( - env["SYNTHESIZER_BASE_URL"], - env["SYNTHESIZER_API_KEY"], - env["SYNTHESIZER_MODEL"], - ) - if config["if_trainee_model"]: - test_api_connection( - env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"] - ) + engine = Engine(config, operators) - # Initialize GraphGen - graph_gen = init_graph_gen(config, env) - graph_gen.clear() - graph_gen.progress_bar = progress + # Start with an empty dataset to kick off the pipeline + ds = ray.data.from_items([]) - try: - ctx = Context(config=config, graph_gen=graph_gen) - ops = collect_ops(config, graph_gen) - Engine(max_workers=config.get("max_workers", 4)).run(ops, ctx) - - # Save output - output_data = graph_gen.qa_storage.data - with tempfile.NamedTemporaryFile( - mode="w", suffix=".jsonl", delete=False, encoding="utf-8" - ) as tmpfile: - json.dump(output_data, tmpfile, ensure_ascii=False) - output_file = tmpfile.name - - synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client) - trainee_tokens = ( - sum_tokens(graph_gen.trainee_llm_client) - if config["if_trainee_model"] - else 0 - ) - total_tokens = synthesizer_tokens + trainee_tokens - - data_frame = params.token_counter - try: - _update_data = [ - [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)] - ] - new_df = pd.DataFrame(_update_data, columns=data_frame.columns) - data_frame = new_df - - except Exception as e: - raise gr.Error(f"DataFrame operation error: {str(e)}") - - return output_file, gr.DataFrame( - label="Token Stats", - headers=["Source Text Token Count", "Expected Token Usage", "Token Used"], - datatype="str", - interactive=False, - value=data_frame, - visible=True, - wrap=True, - ) + # Execute pipeline + results = engine.execute(ds) + + # 5. Process Output + # Extract the result from the 'generate' node + if "generate" in results: + result_ds = results["generate"] + + # Create a temporary file to save the output + with tempfile.NamedTemporaryFile( + mode="w", suffix=".jsonl", delete=False, encoding="utf-8" + ) as tmpfile: + # Iterate over rows and write to file + for row in result_ds.iter_rows(): + json.dump(row, tmpfile, ensure_ascii=False) + tmpfile.write("\n") + output_file = tmpfile.name + else: + raise gr.Error("Generation step failed to produce output.") + + # Note: Dynamic token counting from distributed actors is not directly available + # via client properties in the new architecture. We return the estimated stats from input. 
+ + return output_file, params.token_counter except Exception as e: # pylint: disable=broad-except raise gr.Error(f"Error occurred: {str(e)}") finally: # Clean up workspace - cleanup_workspace(graph_gen.working_dir) + cleanup_workspace(working_dir) # Optional: keep for debugging or enable with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: @@ -267,7 +231,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: ("简体中文", "zh"), ], value="en", - # label=_("Language"), render=False, container=False, elem_classes=["center-row"], @@ -295,7 +258,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: os.path.join(root_dir, "webui", "translation.json"), lang_btn, placeholder_langs=["en", "zh"], - persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0 + persistant=False, ): lang_btn.render() @@ -701,7 +664,6 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: outputs=[output, token_counter], ) - if __name__ == "__main__": demo.queue(api_open=False, default_concurrency_limit=2) demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)
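
Note on the per-context logging introduced in graphgen/utils/log.py: `logger` is now a proxy that resolves the active `logging.Logger` through a `contextvars.ContextVar`, so the driver process and each Ray worker can write to separate log files while shared code keeps calling `logger.info(...)` unchanged. Below is a minimal usage sketch; `set_logger`, `CURRENT_LOGGER_VAR`, and `logger` come from this patch, while the `ExampleWorker` actor, the file paths, and the per-call re-binding are illustrative assumptions, not part of the change itself.

import ray

from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger

# Driver side (the same pattern run.py and webui/app.py use): build a named
# logger, bind it to the current context, then keep calling logger.* as before.
driver_logger = set_logger("cache/logs/Driver.log", name="GraphGen", if_stream=True)
CURRENT_LOGGER_VAR.set(driver_logger)
logger.info("driver logger bound")


@ray.remote
class ExampleWorker:  # hypothetical actor, not part of this patch
    def __init__(self, log_file: str):
        # One logger per worker process, written to its own file.
        self._worker_logger = set_logger(log_file, name="ExampleWorker", if_stream=False)

    def process(self, n: int) -> int:
        # Re-bind per call: ContextVar state is not guaranteed to survive
        # across Ray task boundaries, so binding here is the conservative choice.
        CURRENT_LOGGER_VAR.set(self._worker_logger)
        logger.info("processing %d items", n)
        return n


if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    worker = ExampleWorker.remote("cache/logs/ExampleWorker.log")
    print(ray.get(worker.process.remote(3)))

One caveat worth keeping in mind: CURRENT_LOGGER_VAR is created without a default, so before any binding ContextVar.get() raises LookupError rather than reaching the RuntimeError branch in get_current_logger; callers should bind a logger before the first logger.* call in a new process or context.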