diff --git a/app.py b/app.py index 1179a7d0fbd5e8ad72b12ea9585a6690e348e978..7af93c3d970e0fe6c1f62a9a3f3cf234d5562e73 100644 --- a/app.py +++ b/app.py @@ -468,7 +468,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: label="TPM", minimum=5000, maximum=5000000, - value=50000, + value=100000, step=1000, interactive=True, visible=True, diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 30b00144e2b66d9e9659af662e9e0210b507f866..ace331d571e419f84137c2ea97d314a4658a7e51 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -1,5 +1,7 @@ +from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder from .base_llm_client import BaseLLMClient +from .base_partitioner import BasePartitioner from .base_reader import BaseReader from .base_splitter import BaseSplitter from .base_storage import ( diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..271f20f238bd7693669d62e3a1db852169ddebb8 --- /dev/null +++ b/graphgen/bases/base_generator.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any + +from graphgen.bases.base_llm_client import BaseLLMClient + + +@dataclass +class BaseGenerator(ABC): + """ + Generate QAs based on given prompts. + """ + + llm_client: BaseLLMClient + + @staticmethod + @abstractmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + """Build prompt for LLM based on the given batch""" + + @staticmethod + @abstractmethod + def parse_response(response: str) -> Any: + """Parse the LLM response and return the generated QAs""" + + async def generate( + self, + batch: tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ], + ) -> dict[str, Any]: + """ + Generate QAs based on a given batch. 
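+        A batch is a (nodes, edges) tuple: nodes is a list of
+        (node_id, node_data) pairs and edges a list of
+        (source_id, target_id, edge_data) triples (inferred from the type
+        hints; the shape matches the batches produced by
+        BasePartitioner.community2batch below).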
+ :param batch + :return: QA pairs + """ + result = {} + prompt = self.build_prompt(batch) + response = await self.llm_client.generate_answer(prompt) + qa_pairs = self.parse_response(response) # generate one or more QA pairs + result.update(qa_pairs) + return result + + @staticmethod + def format_generation_results( + results: list[dict], output_data_format: str + ) -> list[dict[str, Any]]: + if output_data_format == "Alpaca": + results = [ + { + "instruction": v["question"], + "input": "", + "output": v["answer"], + } + for item in results + for k, v in item.items() + ] + elif output_data_format == "Sharegpt": + results = [ + { + "conversations": [ + {"from": "human", "value": v["question"]}, + {"from": "gpt", "value": v["answer"]}, + ] + } + for item in results + for k, v in item.items() + ] + elif output_data_format == "ChatML": + results = [ + { + "messages": [ + {"role": "user", "content": v["question"]}, + {"role": "assistant", "content": v["answer"]}, + ] + } + for item in results + for k, v in item.items() + ] + else: + raise ValueError(f"Unknown output data format: {output_data_format}") + return results diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..a3739e541ec26a2fc727c3f0d89cc4fc73767423 --- /dev/null +++ b/graphgen/bases/base_partitioner.py @@ -0,0 +1,76 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, List + +from graphgen.bases.base_storage import BaseGraphStorage +from graphgen.bases.datatypes import Community + + +@dataclass +class BasePartitioner(ABC): + @abstractmethod + async def partition( + self, + g: BaseGraphStorage, + **kwargs: Any, + ) -> List[Community]: + """ + Graph -> Communities + :param g: Graph storage instance + :param kwargs: Additional parameters for partitioning + :return: List of communities + """ + + @staticmethod + async def community2batch( + communities: List[Community], g: BaseGraphStorage + ) -> list[ + tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ] + ]: + """ + Convert communities to batches of nodes and edges. + :param communities + :param g: Graph storage instance + :return: List of batches, each batch is a tuple of (nodes, edges) + """ + batches = [] + for comm in communities: + nodes = comm.nodes + edges = comm.edges + nodes_data = [] + for node in nodes: + node_data = await g.get_node(node) + if node_data: + nodes_data.append((node, node_data)) + edges_data = [] + for u, v in edges: + edge_data = await g.get_edge(u, v) + if edge_data: + edges_data.append((u, v, edge_data)) + else: + edge_data = await g.get_edge(v, u) + if edge_data: + edges_data.append((v, u, edge_data)) + batches.append((nodes_data, edges_data)) + return batches + + @staticmethod + def _build_adjacency_list( + nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]] + ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]: + """ + Build adjacency list and edge set from nodes and edges. 
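+        Illustrative example (inferred from the implementation below): given
+        nodes = [("A", {}), ("B", {}), ("C", {})] and
+        edges = [("A", "B", {}), ("B", "C", {})], the helper returns
+        adj = {"A": ["B"], "B": ["A", "C"], "C": ["B"]} and
+        edge_set = {("A", "B"), ("B", "A"), ("B", "C"), ("C", "B")}.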
+ :param nodes + :param edges + :return: adjacency list, edge set + """ + adj: dict[str, List[str]] = {n[0]: [] for n in nodes} + edge_set: set[tuple[str, str]] = set() + for e in edges: + adj[e[0]].append(e[1]) + adj[e[1]].append(e[0]) + edge_set.add((e[0], e[1])) + edge_set.add((e[1], e[0])) + return adj, edge_set diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index dff837781059657ec7e5eef7583516fe563b521d..6968dca21229759903e72fbca45bfc8961820913 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -78,7 +78,7 @@ class BaseGraphStorage(StorageNameSpace): async def update_node(self, node_id: str, node_data: dict[str, str]): raise NotImplementedError - async def get_all_nodes(self) -> Union[list[dict], None]: + async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: raise NotImplementedError async def get_edge( @@ -91,7 +91,7 @@ class BaseGraphStorage(StorageNameSpace): ): raise NotImplementedError - async def get_all_edges(self) -> Union[list[dict], None]: + async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: raise NotImplementedError async def get_node_edges( diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index 5a3212629a0aa9b333870339580ed0aa4f5c28c3..beb73a77fc2ab7e3bb027c014e85f2406e4a3c3e 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -30,3 +30,11 @@ class Token: @property def logprob(self) -> float: return math.log(self.prob) + + +@dataclass +class Community: + id: Union[int, str] + nodes: List[str] = field(default_factory=list) + edges: List[tuple] = field(default_factory=list) + metadata: dict = field(default_factory=dict) diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml index 2809ca7773c4448ab89eb8fafc9aa6e6796c2967..78ca19effb595060f091f738038128cfabf29fe3 100644 --- a/graphgen/configs/aggregated_config.yaml +++ b/graphgen/configs/aggregated_config.yaml @@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points partition: # graph partition configuration method: ece # ece is a custom partition method based on comprehension loss method_params: - bidirectional: true # whether to traverse the graph in both directions - edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss - expand_method: max_width # expand method, support: max_width, max_depth - isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add - max_depth: 5 # maximum depth for graph traversal - max_extra_edges: 20 # max edges per direction (if expand_method="max_width") - max_tokens: 256 # restricts input length (if expand_method="max_tokens") - loss_strategy: only_edge # defines loss computation focus, support: only_edge, both + max_units_per_community: 20 # max nodes and edges per community + min_units_per_community: 5 # min nodes and edges per community + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss generate: mode: aggregated # atomic, aggregated, multi_hop, cot data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml index 90037ec312dba26eb86c7cc99fd79906d066edf1..d50ea421e38d15e25a17bea20f95a246c7a96201 100644 --- a/graphgen/configs/atomic_config.yaml +++ b/graphgen/configs/atomic_config.yaml @@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether 
the LLM masters the knowledge points quiz_samples: 2 # number of quiz samples to generate re_judge: false # whether to re-judge the existing quiz samples partition: # graph partition configuration - method: ece # ece is a custom partition method based on comprehension loss + method: dfs # partition method, support: dfs, bfs, ece, leiden method_params: - bidirectional: true # whether to traverse the graph in both directions - edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss - expand_method: max_width # expand method, support: max_width, max_depth - isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add - max_depth: 3 # maximum depth for graph traversal - max_extra_edges: 5 # max edges per direction (if expand_method="max_width") - max_tokens: 256 # restricts input length (if expand_method="max_tokens") - loss_strategy: only_edge # defines loss computation focus, support: only_edge, both + max_units_per_community: 1 # atomic partition, one node or edge per community generate: mode: atomic # atomic, aggregated, multi_hop, cot data_format: Alpaca # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml index 69d1e608ae494a09129bcf0b1911bf307d315ea1..87dd346229f0de6df12b9362ad19f27506440aa8 100644 --- a/graphgen/configs/cot_config.yaml +++ b/graphgen/configs/cot_config.yaml @@ -9,11 +9,11 @@ search: # web search configuration quiz_and_judge: # quiz and test whether the LLM masters the knowledge points enabled: false partition: # graph partition configuration - method: leiden # leiden is a community detection algorithm + method: leiden # leiden is a partitioner detection algorithm method_params: max_size: 20 # Maximum size of communities - use_lcc: false - random_seed: 42 + use_lcc: false # whether to use the largest connected component + random_seed: 42 # random seed for partitioning generate: mode: cot # atomic, aggregated, multi_hop, cot data_format: Sharegpt # Alpaca, Sharegpt, ChatML diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml index 1754cec48d195316b164e8c190f9c68e429062a0..09f0c0861c27f087e99019b103042c58632d5267 100644 --- a/graphgen/configs/multi_hop_config.yaml +++ b/graphgen/configs/multi_hop_config.yaml @@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points partition: # graph partition configuration method: ece # ece is a custom partition method based on comprehension loss method_params: - bidirectional: true # whether to traverse the graph in both directions - edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss - expand_method: max_width # expand method, support: max_width, max_depth - isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add - max_depth: 1 # maximum depth for graph traversal - max_extra_edges: 2 # max edges per direction (if expand_method="max_width") - max_tokens: 256 # restricts input length (if expand_method="max_tokens") - loss_strategy: only_edge # defines loss computation focus, support: only_edge, both + max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3 + min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3 + max_tokens_per_community: 10240 # max tokens per community + unit_sampling: random # edge sampling strategy, support: random, max_loss, min_loss generate: mode: multi_hop # strategy 
for generating multi-hop QA pairs data_format: ChatML # Alpaca, Sharegpt, ChatML diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index a0dac1c7cab0125402b86cbb524589f10a57b320..2eb953e934fba516ca4475ce1ecc792e32f35e1f 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -18,21 +18,14 @@ from graphgen.models import ( from graphgen.operators import ( build_kg, chunk_documents, - generate_cot, + generate_qas, judge_statement, + partition_kg, quiz, read_files, search_all, - traverse_graph_for_aggregated, - traverse_graph_for_atomic, - traverse_graph_for_multi_hop, -) -from graphgen.utils import ( - async_to_sync_method, - compute_content_hash, - format_generation_results, - logger, ) +from graphgen.utils import async_to_sync_method, compute_content_hash, logger sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -238,51 +231,20 @@ class GraphGen: @async_to_sync_method async def generate(self, partition_config: Dict, generate_config: Dict): # Step 1: partition the graph - # TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage) - mode = generate_config["mode"] - if mode == "atomic": - results = await traverse_graph_for_atomic( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "multi_hop": - results = await traverse_graph_for_multi_hop( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "aggregated": - results = await traverse_graph_for_aggregated( - self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - partition_config["method_params"], - self.text_chunks_storage, - self.progress_bar, - ) - elif mode == "cot": - results = await generate_cot( - self.graph_storage, - self.synthesizer_llm_client, - method_params=partition_config["method_params"], - ) - else: - raise ValueError(f"Unknown generation mode: {mode}") - # Step 2: generate QA pairs - # TODO + batches = await partition_kg( + self.graph_storage, self.tokenizer_instance, partition_config + ) - # Step 3: format - results = format_generation_results( - results, output_data_format=generate_config["data_format"] + # Step 2: generate QA pairs + results = await generate_qas( + self.synthesizer_llm_client, batches, generate_config ) + if not results: + logger.warning("No QA pairs generated") + return + + # Step 3: store the generated QA pairs await self.qa_storage.upsert(results) await self.qa_storage.index_done_callback() diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3ea152fa284fe687c5a0fb52ac6a613326eb1f38..d9869244f0b0ec1d34dc0c84c613401da07385b7 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,17 +1,24 @@ -from .community.community_detector import CommunityDetector -from .evaluate.length_evaluator import LengthEvaluator -from .evaluate.mtld_evaluator import MTLDEvaluator -from .evaluate.reward_evaluator import RewardEvaluator -from .evaluate.uni_evaluator import UniEvaluator -from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder +from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .generator import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, +) +from .kg_builder import LightRAGKGBuilder from .llm.openai_client import OpenAIClient from 
.llm.topk_token_model import TopkTokenModel +from .partitioner import ( + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, +) from .reader import CsvReader, JsonlReader, JsonReader, TxtReader from .search.db.uniprot_search import UniProtSearch from .search.kg.wiki_search import WikiSearch from .search.web.bing_search import BingSearch from .search.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage.json_storage import JsonKVStorage, JsonListStorage -from .storage.networkx_storage import NetworkXStorage +from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage from .tokenizer import Tokenizer diff --git a/graphgen/models/community/__init__.py b/graphgen/models/community/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/graphgen/models/community/community_detector.py b/graphgen/models/community/community_detector.py deleted file mode 100644 index 0041f4c4a4a57648078ebe650a5c9702d7a17eb5..0000000000000000000000000000000000000000 --- a/graphgen/models/community/community_detector.py +++ /dev/null @@ -1,95 +0,0 @@ -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Dict, List - -from graphgen.models.storage.networkx_storage import NetworkXStorage - - -@dataclass -class CommunityDetector: - """Class for community detection algorithms.""" - - graph_storage: NetworkXStorage = None - method: str = "leiden" - method_params: Dict[str, Any] = None - - async def detect_communities(self) -> Dict[str, int]: - if self.method == "leiden": - return await self._leiden_communities(**self.method_params or {}) - raise ValueError(f"Unknown community detection method: {self.method}") - - async def get_graph(self): - return await self.graph_storage.get_graph() - - async def _leiden_communities( - self, max_size: int = None, **kwargs - ) -> Dict[str, int]: - """ - Detect communities using the Leiden algorithm. - If max_size is given, any community larger than max_size will be split - into smaller sub-communities each having at most max_size nodes. 
- """ - import igraph as ig - import networkx as nx - from leidenalg import ModularityVertexPartition, find_partition - - graph = await self.get_graph() - graph.remove_nodes_from(list(nx.isolates(graph))) - - ig_graph = ig.Graph.TupleList(graph.edges(), directed=False) - - random_seed = kwargs.get("random_seed", 42) - use_lcc = kwargs.get("use_lcc", False) - - communities: Dict[str, int] = {} - if use_lcc: - lcc = ig_graph.components().giant() - partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed) - for part, cluster in enumerate(partition): - for v in cluster: - communities[lcc.vs[v]["name"]] = part - else: - offset = 0 - for component in ig_graph.components(): - subgraph = ig_graph.induced_subgraph(component) - partition = find_partition( - subgraph, ModularityVertexPartition, seed=random_seed - ) - for part, cluster in enumerate(partition): - for v in cluster: - original_node = subgraph.vs[v]["name"] - communities[original_node] = part + offset - offset += len(partition) - - # split large communities if max_size is specified - if max_size is None or max_size <= 0: - return communities - - return await self._split_communities(communities, max_size) - - @staticmethod - async def _split_communities( - communities: Dict[str, int], max_size: int - ) -> Dict[str, int]: - """ - Split communities larger than max_size into smaller sub-communities. - """ - cid2nodes: Dict[int, List[str]] = defaultdict(list) - for node, cid in communities.items(): - cid2nodes[cid].append(node) - - new_communities: Dict[str, int] = {} - new_cid = 0 - for cid, nodes in cid2nodes.items(): - if len(nodes) <= max_size: - for n in nodes: - new_communities[n] = new_cid - new_cid += 1 - else: - for start in range(0, len(nodes), max_size): - sub = nodes[start : start + max_size] - for n in sub: - new_communities[n] = new_cid - new_cid += 1 - - return new_communities diff --git a/graphgen/models/evaluate/__init__.py b/graphgen/models/evaluate/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b445b4248bdb296c3c487a5f45c55d4f92a1ec --- /dev/null +++ b/graphgen/models/evaluator/__init__.py @@ -0,0 +1,4 @@ +from .length_evaluator import LengthEvaluator +from .mtld_evaluator import MTLDEvaluator +from .reward_evaluator import RewardEvaluator +from .uni_evaluator import UniEvaluator diff --git a/graphgen/models/evaluate/base_evaluator.py b/graphgen/models/evaluator/base_evaluator.py similarity index 100% rename from graphgen/models/evaluate/base_evaluator.py rename to graphgen/models/evaluator/base_evaluator.py diff --git a/graphgen/models/evaluate/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py similarity index 90% rename from graphgen/models/evaluate/length_evaluator.py rename to graphgen/models/evaluator/length_evaluator.py index 9aa6c7c0afe8d177b82feb4fb85be7756b9690b5..a7e99896d4c690d76bb7f8453a2b89e6995f9e46 100644 --- a/graphgen/models/evaluate/length_evaluator.py +++ b/graphgen/models/evaluator/length_evaluator.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluate.base_evaluator import BaseEvaluator +from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.models.tokenizer import Tokenizer from graphgen.utils import create_event_loop diff --git 
a/graphgen/models/evaluate/mtld_evaluator.py b/graphgen/models/evaluator/mtld_evaluator.py similarity index 97% rename from graphgen/models/evaluate/mtld_evaluator.py rename to graphgen/models/evaluator/mtld_evaluator.py index fc563d1c2a7e5542261f81d46ee906cbe5655d80..79924fe995819bb0812c4e382646f4b12d5e56b4 100644 --- a/graphgen/models/evaluate/mtld_evaluator.py +++ b/graphgen/models/evaluator/mtld_evaluator.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, field from typing import Set from graphgen.bases.datatypes import QAPair -from graphgen.models.evaluate.base_evaluator import BaseEvaluator +from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language nltk_helper = NLTKHelper() diff --git a/graphgen/models/evaluate/reward_evaluator.py b/graphgen/models/evaluator/reward_evaluator.py similarity index 100% rename from graphgen/models/evaluate/reward_evaluator.py rename to graphgen/models/evaluator/reward_evaluator.py diff --git a/graphgen/models/evaluate/uni_evaluator.py b/graphgen/models/evaluator/uni_evaluator.py similarity index 100% rename from graphgen/models/evaluate/uni_evaluator.py rename to graphgen/models/evaluator/uni_evaluator.py diff --git a/graphgen/models/generator/__init__.py b/graphgen/models/generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dab300eebd1cb884b0810f6850189a29b68a299a --- /dev/null +++ b/graphgen/models/generator/__init__.py @@ -0,0 +1,4 @@ +from .aggregated_generator import AggregatedGenerator +from .atomic_generator import AtomicGenerator +from .cot_generator import CoTGenerator +from .multi_hop_generator import MultiHopGenerator diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..37c54c72c4dae0a6947ce9546fef0d121d3ace3c --- /dev/null +++ b/graphgen/models/generator/aggregated_generator.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import AGGREGATED_GENERATION_PROMPT +from graphgen.utils import compute_content_hash, detect_main_language, logger + + +@dataclass +class AggregatedGenerator(BaseGenerator): + """ + Aggregated Generator follows a TWO-STEP process: + 1. rephrase: Rephrase the input nodes and edges into a coherent text that maintains the original meaning. + The rephrased text is considered as answer to be used in the next step. + 2. question generation: Generate relevant questions based on the rephrased text. + """ + + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + """ + Build prompts for REPHRASE. + :param batch + :return: + """ + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. {node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + relations_str = "\n".join( + [ + f"{index + 1}. 
{edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + language = detect_main_language(entities_str + relations_str) + + # TODO: configure add_context + # if add_context: + # original_ids = [ + # node["source_id"].split("")[0] for node in _process_nodes + # ] + [edge[2]["source_id"].split("")[0] for edge in _process_edges] + # original_ids = list(set(original_ids)) + # original_text = await text_chunks_storage.get_by_ids(original_ids) + # original_text = "\n".join( + # [ + # f"{index + 1}. {text['content']}" + # for index, text in enumerate(original_text) + # ] + # ) + prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format( + language=language, entities=entities_str, relationships=relations_str + ) + return prompt + + @staticmethod + def parse_rephrased_text(response: str) -> str: + """ + Parse the rephrased text from the response. + :param response: + :return: rephrased text + """ + if "Rephrased Text:" in response: + rephrased_text = response.split("Rephrased Text:")[1].strip() + elif "重述文本:" in response: + rephrased_text = response.split("重述文本:")[1].strip() + else: + rephrased_text = response.strip() + return rephrased_text.strip('"') + + @staticmethod + def _build_prompt_for_question_generation(answer: str) -> str: + """ + Build prompts for QUESTION GENERATION. + :param answer: + :return: + """ + language = detect_main_language(answer) + prompt = AGGREGATED_GENERATION_PROMPT[language]["QUESTION_GENERATION"].format( + answer=answer + ) + return prompt + + @staticmethod + def parse_response(response: str) -> dict: + if response.startswith("Question:"): + question = response[len("Question:") :].strip() + elif response.startswith("问题:"): + question = response[len("问题:") :].strip() + else: + question = response.strip() + return { + "question": question, + } + + async def generate( + self, + batch: tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ], + ) -> dict[str, Any]: + """ + Generate QAs based on a given batch. 
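+        The returned dict maps a content hash of the question to the pair,
+        e.g. (illustrative): {"<hash>": {"question": "...", "answer": "<rephrased text>"}}.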
+ :param batch + :return: QA pairs + """ + result = {} + rephrasing_prompt = self.build_prompt(batch) + response = await self.llm_client.generate_answer(rephrasing_prompt) + context = self.parse_rephrased_text(response) + question_generation_prompt = self._build_prompt_for_question_generation(context) + response = await self.llm_client.generate_answer(question_generation_prompt) + question = self.parse_response(response)["question"] + logger.info("Question: %s", question) + logger.info("Answer: %s", context) + qa_pairs = { + compute_content_hash(question): { + "question": question, + "answer": context, + } + } + result.update(qa_pairs) + return result diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..cb566fdf2ae0386cb2438b3a2b9d3cbf45a574ce --- /dev/null +++ b/graphgen/models/generator/atomic_generator.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import ATOMIC_GENERATION_PROMPT +from graphgen.utils import compute_content_hash, detect_main_language, logger + + +@dataclass +class AtomicGenerator(BaseGenerator): + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + nodes, edges = batch + context = "" + for node in nodes: + context += f"- {node[0]}: {node[1]['description']}\n" + for edge in edges: + context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n" + language = detect_main_language(context) + + prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context) + return prompt + + @staticmethod + def parse_response(response: str) -> dict: + """ + AtomicGenerator normally generates one QA pair per response. + So we just need to parse one QA pair from the response. + :param response: + :return: + """ + if "Question:" in response and "Answer:" in response: + question = response.split("Question:")[1].split("Answer:")[0].strip() + answer = response.split("Answer:")[1].strip() + elif "问题:" in response and "答案:" in response: + question = response.split("问题:")[1].split("答案:")[0].strip() + answer = response.split("答案:")[1].strip() + else: + logger.warning("Failed to parse response: %s", response) + return {} + question = question.strip('"') + answer = answer.strip('"') + logger.info("Question: %s", question) + logger.info("Answer: %s", answer) + return { + compute_content_hash(question): { + "question": question, + "answer": answer, + } + } diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc4fe85156c47010cd0670ae24d7cae7c687943 --- /dev/null +++ b/graphgen/models/generator/cot_generator.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import COT_GENERATION_PROMPT +from graphgen.utils import compute_content_hash, detect_main_language, logger + + +@dataclass +class CoTGenerator(BaseGenerator): + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + """ + Build prompts for COT Template Design. + :param batch: + :return: + """ + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. 
{node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + relationships_str = "\n".join( + [ + f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + language = detect_main_language(entities_str + relationships_str) + prompt = COT_GENERATION_PROMPT[language]["COT_TEMPLATE_DESIGN"].format( + entities=entities_str, relationships=relationships_str + ) + return prompt + + @staticmethod + def build_prompt_for_cot_generation( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]], + question: str, + reasoning_path: str, + ) -> str: + """ + Build prompts for COT Generation. + """ + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. {node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + relationships_str = "\n".join( + [ + f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + language = detect_main_language(entities_str + relationships_str) + prompt = COT_GENERATION_PROMPT[language]["COT_GENERATION"].format( + entities=entities_str, + relationships=relationships_str, + question=question, + reasoning_template=reasoning_path, + ) + return prompt + + @staticmethod + def parse_response(response: str) -> dict: + if "Question:" in response and "Reasoning-Path Design:" in response: + question = ( + response.split("Question:")[1] + .split("Reasoning-Path Design:")[0] + .strip() + ) + reasoning_path = response.split("Reasoning-Path Design:")[1].strip() + elif "问题:" in response and "推理路径设计:" in response: + question = response.split("问题:")[1].split("推理路径设计:")[0].strip() + reasoning_path = response.split("推理路径设计:")[1].strip() + else: + logger.warning("Failed to parse CoT template: %s", response) + return {} + + question = question.strip('"') + reasoning_path = reasoning_path.strip('"') + logger.info("CoT Question: %s", question) + logger.info("CoT Reasoning Path: %s", reasoning_path) + return { + "question": question, + "reasoning_path": reasoning_path, + } + + async def generate( + self, + batch: tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ], + ) -> dict[str, Any]: + """ + Generate QAs based on a given batch. 
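+        The returned dict maps a content hash of the question to the pair plus
+        the reasoning template, e.g. (illustrative):
+        {"<hash>": {"question": "...", "answer": "...", "reasoning_path": "..."}}.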
+ :param batch + :return: QA pairs + """ + result = {} + prompt = self.build_prompt(batch) + response = await self.llm_client.generate_answer(prompt) + response = self.parse_response(response) + question, reasoning_path = response["question"], response["reasoning_path"] + prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path) + cot_answer = await self.llm_client.generate_answer(prompt) + logger.info("CoT Answer: %s", cot_answer) + qa_pairs = { + compute_content_hash(question): { + "question": question, + "answer": cot_answer, + "reasoning_path": reasoning_path, + } + } + result.update(qa_pairs) + return result diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..257fc1ddeccda97a44e19dd8638236979afa5b47 --- /dev/null +++ b/graphgen/models/generator/multi_hop_generator.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import Any + +from graphgen.bases import BaseGenerator +from graphgen.templates import MULTI_HOP_GENERATION_PROMPT +from graphgen.utils import compute_content_hash, detect_main_language, logger + + +@dataclass +class MultiHopGenerator(BaseGenerator): + @staticmethod + def build_prompt( + batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] + ) -> str: + nodes, edges = batch + entities_str = "\n".join( + [ + f"{index + 1}. {node[0]}: {node[1]['description']}" + for index, node in enumerate(nodes) + ] + ) + + relationships_str = "\n".join( + [ + f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}" + for index, edge in enumerate(edges) + ] + ) + language = detect_main_language(entities_str + relationships_str) + prompt = MULTI_HOP_GENERATION_PROMPT[language].format( + entities=entities_str, relationships=relationships_str + ) + return prompt + + @staticmethod + def parse_response(response: str) -> dict: + if "Question:" in response and "Answer:" in response: + question = response.split("Question:")[1].split("Answer:")[0].strip() + answer = response.split("Answer:")[1].strip() + elif "问题:" in response and "答案:" in response: + question = response.split("问题:")[1].split("答案:")[0].strip() + answer = response.split("答案:")[1].strip() + else: + logger.warning("Failed to parse response: %s", response) + return {} + question = question.strip('"') + answer = answer.strip('"') + logger.info("Question: %s", question) + logger.info("Answer: %s", answer) + return { + compute_content_hash(question): { + "question": question, + "answer": answer, + } + } diff --git a/graphgen/models/kg_builder/__init__.py b/graphgen/models/kg_builder/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4d630c5f85d41a3d2c1cf3e52384249578f020bb 100644 --- a/graphgen/models/kg_builder/__init__.py +++ b/graphgen/models/kg_builder/__init__.py @@ -0,0 +1 @@ +from .light_rag_kg_builder import LightRAGKGBuilder diff --git a/graphgen/models/llm/limitter.py b/graphgen/models/llm/limitter.py index 01cb1f709f17632652b36a1da0b21e963e823df0..5aee4501fc660675f8fcbc31a3ea89ff43b981e7 100644 --- a/graphgen/models/llm/limitter.py +++ b/graphgen/models/llm/limitter.py @@ -1,17 +1,17 @@ +import asyncio import time from datetime import datetime, timedelta -import asyncio from graphgen.utils import logger class RPM: - def __init__(self, rpm: int = 1000): self.rpm = rpm - self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0} + self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0} - def get_minute_slot(self): + 
@staticmethod + def get_minute_slot(): current_time = time.time() dt_object = datetime.fromtimestamp(current_time) total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute @@ -22,37 +22,35 @@ class RPM: dt_object = datetime.fromtimestamp(current) minute_slot = self.get_minute_slot() - if self.record['rpm_slot'] == minute_slot: + if self.record["rpm_slot"] == minute_slot: # check RPM exceed - if self.record['counter'] >= self.rpm: + if self.record["counter"] >= self.rpm: # wait until next minute - next_minute = dt_object.replace( - second=0, microsecond=0) + timedelta(minutes=1) + next_minute = dt_object.replace(second=0, microsecond=0) + timedelta( + minutes=1 + ) _next = next_minute.timestamp() sleep_time = abs(_next - current) if not silent: - logger.info('RPM sleep %s', sleep_time) + logger.info("RPM sleep %s", sleep_time) await asyncio.sleep(sleep_time) - self.record = { - 'rpm_slot': self.get_minute_slot(), - 'counter': 0 - } + self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0} else: - self.record = {'rpm_slot': self.get_minute_slot(), 'counter': 0} - self.record['counter'] += 1 + self.record = {"rpm_slot": self.get_minute_slot(), "counter": 0} + self.record["counter"] += 1 if not silent: logger.debug(self.record) class TPM: - def __init__(self, tpm: int = 20000): self.tpm = tpm - self.record = {'tpm_slot': self.get_minute_slot(), 'counter': 0} + self.record = {"tpm_slot": self.get_minute_slot(), "counter": 0} - def get_minute_slot(self): + @staticmethod + def get_minute_slot(): current_time = time.time() dt_object = datetime.fromtimestamp(current_time) total_minutes_since_midnight = dt_object.hour * 60 + dt_object.minute @@ -64,25 +62,25 @@ class TPM: minute_slot = self.get_minute_slot() # get next slot, skip - if self.record['tpm_slot'] != minute_slot: - self.record = {'tpm_slot': minute_slot, 'counter': token_count} + if self.record["tpm_slot"] != minute_slot: + self.record = {"tpm_slot": minute_slot, "counter": token_count} return # check RPM exceed - self.record['counter'] += token_count - if self.record['counter'] > self.tpm: + old_counter = self.record["counter"] + self.record["counter"] += token_count + if self.record["counter"] > self.tpm: + logger.info("Current TPM: %s, limit: %s", old_counter, self.tpm) # wait until next minute - next_minute = dt_object.replace( - second=0, microsecond=0) + timedelta(minutes=1) + next_minute = dt_object.replace(second=0, microsecond=0) + timedelta( + minutes=1 + ) _next = next_minute.timestamp() sleep_time = abs(_next - current) - logger.info('TPM sleep %s', sleep_time) + logger.warning("TPM limit exceeded, wait %s seconds", sleep_time) await asyncio.sleep(sleep_time) - self.record = { - 'tpm_slot': self.get_minute_slot(), - 'counter': token_count - } + self.record = {"tpm_slot": self.get_minute_slot(), "counter": token_count} if not silent: logger.debug(self.record) diff --git a/graphgen/models/llm/openai_client.py b/graphgen/models/llm/openai_client.py index 9f0d276a90f719a4ac38a01078a174a94557512d..30ec39c831f19d9639fe70867126d9cb3af26dca 100644 --- a/graphgen/models/llm/openai_client.py +++ b/graphgen/models/llm/openai_client.py @@ -39,6 +39,8 @@ class OpenAIClient(BaseLLMClient): seed: Optional[int] = None, topk_per_token: int = 5, # number of topk tokens to generate for each token request_limit: bool = False, + rpm: Optional[RPM] = None, + tpm: Optional[TPM] = None, **kwargs: Any, ): super().__init__(**kwargs) @@ -51,8 +53,8 @@ class OpenAIClient(BaseLLMClient): self.token_usage: list = [] self.request_limit 
= request_limit - self.rpm = RPM(rpm=1000) - self.tpm = TPM(tpm=50000) + self.rpm = rpm or RPM() + self.tpm = tpm or TPM() self.__post_init__() diff --git a/graphgen/models/partitioner/__init__.py b/graphgen/models/partitioner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d37a5d4ab99c2373bc925324a998d2739876b20 --- /dev/null +++ b/graphgen/models/partitioner/__init__.py @@ -0,0 +1,4 @@ +from .bfs_partitioner import BFSPartitioner +from .dfs_partitioner import DFSPartitioner +from .ece_partitioner import ECEPartitioner +from .leiden_partitioner import LeidenPartitioner diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc51486e8058129ad5f5aacc2c4c8b36fa87ff5 --- /dev/null +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -0,0 +1,83 @@ +import random +from collections import deque +from dataclasses import dataclass +from typing import Any, List + +from graphgen.bases import BaseGraphStorage, BasePartitioner +from graphgen.bases.datatypes import Community + +NODE_UNIT: str = "n" +EDGE_UNIT: str = "e" + + +@dataclass +class BFSPartitioner(BasePartitioner): + """ + BFS partitioner that partitions the graph into communities of a fixed size. + 1. Randomly choose a unit. + 2. Expand the community using BFS until the max unit size is reached. + (A unit is a node or an edge.) + """ + + async def partition( + self, + g: BaseGraphStorage, + max_units_per_community: int = 1, + **kwargs: Any, + ) -> List[Community]: + nodes = await g.get_all_nodes() + edges = await g.get_all_edges() + + adj, _ = self._build_adjacency_list(nodes, edges) + + used_n: set[str] = set() + used_e: set[frozenset[str]] = set() + communities: List[Community] = [] + + units = [(NODE_UNIT, n[0]) for n in nodes] + [ + (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges + ] + random.shuffle(units) + + for kind, seed in units: + if (kind == NODE_UNIT and seed in used_n) or ( + kind == EDGE_UNIT and seed in used_e + ): + continue + + comm_n: List[str] = [] + comm_e: List[tuple[str, str]] = [] + queue: deque[tuple[str, Any]] = deque([(kind, seed)]) + cnt = 0 + + while queue and cnt < max_units_per_community: + k, it = queue.popleft() + if k == NODE_UNIT: + if it in used_n: + continue + used_n.add(it) + comm_n.append(it) + cnt += 1 + for nei in adj[it]: + e_key = frozenset((it, nei)) + if e_key not in used_e: + queue.append((EDGE_UNIT, e_key)) + else: + if it in used_e: + continue + used_e.add(it) + + u, v = it + comm_e.append((u, v)) + cnt += 1 + # push nodes that are not visited + for n in it: + if n not in used_n: + queue.append((NODE_UNIT, n)) + + if comm_n or comm_e: + communities.append( + Community(id=len(communities), nodes=comm_n, edges=comm_e) + ) + + return communities diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a64a9b3b6b683a97a0e52a4fc5a8a7988270b4 --- /dev/null +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -0,0 +1,80 @@ +import random +from dataclasses import dataclass +from typing import Any, List + +from graphgen.bases import BaseGraphStorage, BasePartitioner +from graphgen.bases.datatypes import Community + +NODE_UNIT: str = "n" +EDGE_UNIT: str = "e" + + +@dataclass +class DFSPartitioner(BasePartitioner): + """ + DFS partitioner that partitions the graph into communities of a fixed size. + 1. 
Randomly choose a unit. + 2. Random walk using DFS until the community reaches the max unit size. + (In GraphGen, a unit is defined as a node or an edge.) + """ + + async def partition( + self, + g: BaseGraphStorage, + max_units_per_community: int = 1, + **kwargs: Any, + ) -> List[Community]: + nodes = await g.get_all_nodes() + edges = await g.get_all_edges() + + adj, _ = self._build_adjacency_list(nodes, edges) + + used_n: set[str] = set() + used_e: set[frozenset[str]] = set() + communities: List[Community] = [] + + units = [(NODE_UNIT, n[0]) for n in nodes] + [ + (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges + ] + random.shuffle(units) + + for kind, seed in units: + if (kind == NODE_UNIT and seed in used_n) or ( + kind == EDGE_UNIT and seed in used_e + ): + continue + + comm_n, comm_e = [], [] + stack = [(kind, seed)] + cnt = 0 + + while stack and cnt < max_units_per_community: + k, it = stack.pop() + if k == NODE_UNIT: + if it in used_n: + continue + used_n.add(it) + comm_n.append(it) + cnt += 1 + for nei in adj[it]: + e_key = frozenset((it, nei)) + if e_key not in used_e: + stack.append((EDGE_UNIT, e_key)) + break + else: + if it in used_e: + continue + used_e.add(it) + comm_e.append(tuple(it)) + cnt += 1 + # push neighboring nodes + for n in it: + if n not in used_n: + stack.append((NODE_UNIT, n)) + + if comm_n or comm_e: + communities.append( + Community(id=len(communities), nodes=comm_n, edges=comm_e) + ) + + return communities diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..fe28d15618cccf631bbdfe2a340a4ff7678fc55f --- /dev/null +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -0,0 +1,163 @@ +import asyncio +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple + +from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.bases import BaseGraphStorage +from graphgen.bases.datatypes import Community +from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner + +NODE_UNIT: str = "n" +EDGE_UNIT: str = "e" + + +@dataclass +class ECEPartitioner(BFSPartitioner): + """ + ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE). + We calculate ECE for edges in KG (represented as 'comprehension loss') + and group edges with similar ECE values into the same community. + 1. Select a sampling strategy. + 2. Choose a unit based on the sampling strategy. + 2. Expand the community using BFS. + 3. When expending, prefer to add units with the sampling strategy. + 4. Stop when the max unit size is reached or the max input length is reached. + (A unit is a node or an edge.) 
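+
+    Illustrative usage (a sketch; assumes graph_storage is an initialised
+    BaseGraphStorage whose node and edge data carry a "loss" field, which the
+    min_loss/max_loss sampling strategies require):
+
+        partitioner = ECEPartitioner()
+        communities = await partitioner.partition(
+            graph_storage,
+            max_units_per_community=20,
+            min_units_per_community=5,
+            max_tokens_per_community=10240,
+            unit_sampling="max_loss",
+        )
+        batches = await partitioner.community2batch(communities, graph_storage)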
+ """ + + @staticmethod + def _sort_units(units: list, edge_sampling: str) -> list: + """ + Sort units with edge sampling strategy + + :param units: total units + :param edge_sampling: edge sampling strategy (random, min_loss, max_loss) + :return: sorted units + """ + if edge_sampling == "random": + random.shuffle(units) + elif edge_sampling == "min_loss": + units = sorted( + units, + key=lambda x: x[-1]["loss"], + ) + elif edge_sampling == "max_loss": + units = sorted( + units, + key=lambda x: x[-1]["loss"], + reverse=True, + ) + else: + raise ValueError(f"Invalid edge sampling: {edge_sampling}") + return units + + async def partition( + self, + g: BaseGraphStorage, + max_units_per_community: int = 10, + min_units_per_community: int = 1, + max_tokens_per_community: int = 10240, + unit_sampling: str = "random", + **kwargs: Any, + ) -> List[Community]: + nodes: List[Tuple[str, dict]] = await g.get_all_nodes() + edges: List[Tuple[str, str, dict]] = await g.get_all_edges() + + adj, _ = self._build_adjacency_list(nodes, edges) + node_dict = dict(nodes) + edge_dict = {frozenset((u, v)): d for u, v, d in edges} + + all_units: List[Tuple[str, Any, dict]] = [ + (NODE_UNIT, nid, d) for nid, d in nodes + ] + [(EDGE_UNIT, frozenset((u, v)), d) for u, v, d in edges] + + used_n: Set[str] = set() + used_e: Set[frozenset[str]] = set() + communities: List = [] + + all_units = self._sort_units(all_units, unit_sampling) + + async def _grow_community( + seed_unit: Tuple[str, Any, dict] + ) -> Optional[Community]: + nonlocal used_n, used_e + + community_nodes: Dict[str, dict] = {} + community_edges: Dict[frozenset[str], dict] = {} + queue: asyncio.Queue = asyncio.Queue() + token_sum = 0 + + async def _add_unit(u): + nonlocal token_sum + t, i, d = u + if t == NODE_UNIT: # node + if i in used_n or i in community_nodes: + return False + community_nodes[i] = d + used_n.add(i) + else: # edge + if i in used_e or i in community_edges: + return False + community_edges[i] = d + used_e.add(i) + token_sum += d.get("length", 0) + return True + + await _add_unit(seed_unit) + await queue.put(seed_unit) + + # BFS + while not queue.empty(): + if ( + len(community_nodes) + len(community_edges) + >= max_units_per_community + or token_sum >= max_tokens_per_community + ): + break + + cur_type, cur_id, _ = await queue.get() + + neighbors: List[Tuple[str, Any, dict]] = [] + if cur_type == NODE_UNIT: + for nb_id in adj.get(cur_id, []): + e_key = frozenset((cur_id, nb_id)) + if e_key not in used_e and e_key not in community_edges: + neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key])) + else: + for n_id in cur_id: + if n_id not in used_n and n_id not in community_nodes: + neighbors.append((NODE_UNIT, n_id, node_dict[n_id])) + + neighbors = self._sort_units(neighbors, unit_sampling) + for nb in neighbors: + if ( + len(community_nodes) + len(community_edges) + >= max_units_per_community + or token_sum >= max_tokens_per_community + ): + break + if await _add_unit(nb): + await queue.put(nb) + + if len(community_nodes) + len(community_edges) < min_units_per_community: + return None + + return Community( + id=len(communities), + nodes=list(community_nodes.keys()), + edges=[(u, v) for (u, v), _ in community_edges.items()], + ) + + async for unit in tqdm_async(all_units, desc="ECE partition"): + utype, uid, _ = unit + if (utype == NODE_UNIT and uid in used_n) or ( + utype == EDGE_UNIT and uid in used_e + ): + continue + comm = await _grow_community(unit) + if comm is not None: + communities.append(comm) + + return communities diff --git 
a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa38ae99b4b981d38a9b00ff0b9079beb12afc1 --- /dev/null +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -0,0 +1,120 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Set, Tuple + +import igraph as ig +from leidenalg import ModularityVertexPartition, find_partition + +from graphgen.bases import BaseGraphStorage, BasePartitioner +from graphgen.bases.datatypes import Community + + +@dataclass +class LeidenPartitioner(BasePartitioner): + """ + Leiden partitioner that partitions the graph into communities using the Leiden algorithm. + """ + + async def partition( + self, + g: BaseGraphStorage, + max_size: int = 20, + use_lcc: bool = False, + random_seed: int = 42, + **kwargs: Any, + ) -> List[Community]: + """ + Leiden Partition follows these steps: + 1. export the graph from graph storage + 2. use the leiden algorithm to detect communities, get {node: community_id} + 3. split large communities if max_size is given + 4. convert {node: community_id} to List[Community] + :param g + :param max_size: maximum size of each community, if None or <=0, no limit + :param use_lcc: whether to use the largest connected component only + :param random_seed + :param kwargs: other parameters for the leiden algorithm + :return: + """ + nodes = await g.get_all_nodes() # List[Tuple[str, dict]] + edges = await g.get_all_edges() # List[Tuple[str, str, dict]] + + node2cid: Dict[str, int] = await self._run_leiden( + nodes, edges, use_lcc, random_seed + ) + + if max_size is not None and max_size > 0: + node2cid = await self._split_communities(node2cid, max_size) + + cid2nodes: Dict[int, List[str]] = defaultdict(list) + for n, cid in node2cid.items(): + cid2nodes[cid].append(n) + + communities: List[Community] = [] + for cid, nodes in cid2nodes.items(): + node_set: Set[str] = set(nodes) + comm_edges: List[Tuple[str, str]] = [ + (u, v) for u, v, _ in edges if u in node_set and v in node_set + ] + communities.append(Community(id=cid, nodes=nodes, edges=comm_edges)) + return communities + + @staticmethod + async def _run_leiden( + nodes: List[Tuple[str, dict]], + edges: List[Tuple[str, str, dict]], + use_lcc: bool = False, + random_seed: int = 42, + ) -> Dict[str, int]: + # build igraph + ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False) + + # remove isolated nodes + ig_graph.delete_vertices(ig_graph.vs.select(_degree_eq=0)) + + node2cid: Dict[str, int] = {} + if use_lcc: + lcc = ig_graph.components().giant() + partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed) + for part_id, cluster in enumerate(partition): + for v in cluster: + node2cid[lcc.vs[v]["name"]] = part_id + else: + offset = 0 + for component in ig_graph.components(): + subgraph = ig_graph.induced_subgraph(component) + partition = find_partition( + subgraph, ModularityVertexPartition, seed=random_seed + ) + for part_id, cluster in enumerate(partition): + for v in cluster: + original_node = subgraph.vs[v]["name"] + node2cid[original_node] = part_id + offset + offset += len(partition) + return node2cid + + @staticmethod + async def _split_communities( + node2cid: Dict[str, int], max_size: int + ) -> Dict[str, int]: + """ + Split communities larger than max_size into smaller sub-communities. 
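+        Illustrative example: with max_size=2, the mapping
+        {"a": 0, "b": 0, "c": 0, "d": 1} becomes {"a": 0, "b": 0, "c": 1, "d": 2}.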
+ """ + cid2nodes: Dict[int, List[str]] = defaultdict(list) + for n, cid in node2cid.items(): + cid2nodes[cid].append(n) + + new_mapping: Dict[str, int] = {} + new_cid = 0 + for nodes in cid2nodes.values(): + if len(nodes) <= max_size: + for n in nodes: + new_mapping[n] = new_cid + new_cid += 1 + else: + for start in range(0, len(nodes), max_size): + chunk = nodes[start : start + max_size] + for n in chunk: + new_mapping[n] = new_cid + new_cid += 1 + return new_mapping diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..56338984f199029eb065e2dbbcaad0fa767f09fb 100644 --- a/graphgen/models/storage/__init__.py +++ b/graphgen/models/storage/__init__.py @@ -0,0 +1,2 @@ +from .json_storage import JsonKVStorage, JsonListStorage +from .networkx_storage import NetworkXStorage diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py index 28baebdae7106a7e596b9fe24eb18b9a001918a7..539ab8421550c3bab446cc3858abd7c7e6a44dbc 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/networkx_storage.py @@ -102,8 +102,8 @@ class NetworkXStorage(BaseGraphStorage): async def get_node(self, node_id: str) -> Union[dict, None]: return self._graph.nodes.get(node_id) - async def get_all_nodes(self) -> Union[list[dict], None]: - return self._graph.nodes(data=True) + async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]: + return list(self._graph.nodes(data=True)) async def node_degree(self, node_id: str) -> int: return self._graph.degree(node_id) @@ -116,8 +116,8 @@ class NetworkXStorage(BaseGraphStorage): ) -> Union[dict, None]: return self._graph.edges.get((source_node_id, target_node_id)) - async def get_all_edges(self) -> Union[list[dict], None]: - return self._graph.edges(data=True) + async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]: + return list(self._graph.edges(data=True)) async def get_node_edges( self, source_node_id: str diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 11a78972d3b2ed2c35ba218bfbfd83908d561682..88c314977b1156a6d14cec9fb9d5e1107203d065 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,13 +1,8 @@ -from graphgen.operators.build_kg.build_kg import build_kg -from graphgen.operators.generate.generate_cot import generate_cot -from graphgen.operators.search.search_all import search_all - +from .build_kg import build_kg +from .generate import generate_qas from .judge import judge_statement +from .partition import partition_kg from .quiz import quiz from .read import read_files +from .search import search_all from .split import chunk_documents -from .traverse_graph import ( - traverse_graph_for_aggregated, - traverse_graph_for_atomic, - traverse_graph_for_multi_hop, -) diff --git a/graphgen/operators/build_kg/__init__.py b/graphgen/operators/build_kg/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..18766fe6311faf02e6d3fd40f8cc2ef34bb2fb24 100644 --- a/graphgen/operators/build_kg/__init__.py +++ b/graphgen/operators/build_kg/__init__.py @@ -0,0 +1 @@ +from .build_kg import build_kg diff --git a/graphgen/operators/build_kg/split_kg.py b/graphgen/operators/build_kg/split_kg.py deleted file mode 100644 index 6033bc85d35df78fd0e568bce6b6a9c915e7520f..0000000000000000000000000000000000000000 --- a/graphgen/operators/build_kg/split_kg.py +++ /dev/null @@ -1,382 +0,0 @@ -import random -from collections import defaultdict -from 
typing import Dict - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.models import NetworkXStorage -from graphgen.utils import logger - - -async def _get_node_info( - node_id: str, - graph_storage: NetworkXStorage, -) -> dict: - """ - Get node info - - :param node_id: node id - :param graph_storage: graph storage instance - :return: node info - """ - node_data = await graph_storage.get_node(node_id) - return {"node_id": node_id, **node_data} - - -def _get_level_n_edges_by_max_width( - edge_adj_list: dict, - node_dict: dict, - edges: list, - nodes, - src_edge: tuple, - max_depth: int, - bidirectional: bool, - max_extra_edges: int, - edge_sampling: str, - loss_strategy: str = "only_edge", -) -> list: - """ - Get level n edges for an edge. - n is decided by max_depth in traverse_strategy - - :param edge_adj_list - :param node_dict - :param edges - :param nodes - :param src_edge - :param max_depth - :param bidirectional - :param max_extra_edges - :param edge_sampling - :return: level n edges - """ - src_id, tgt_id, _ = src_edge - - level_n_edges = [] - - start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id} - - while max_depth > 0 and max_extra_edges > 0: - max_depth -= 1 - - candidate_edges = [ - edges[edge_id] - for node in start_nodes - for edge_id in edge_adj_list[node] - if not edges[edge_id][2].get("visited", False) - ] - - if not candidate_edges: - break - - if len(candidate_edges) >= max_extra_edges: - if loss_strategy == "both": - er_tuples = [ - ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) - for edge in candidate_edges - ] - candidate_edges = _sort_tuples(er_tuples, edge_sampling)[ - :max_extra_edges - ] - elif loss_strategy == "only_edge": - candidate_edges = _sort_edges(candidate_edges, edge_sampling)[ - :max_extra_edges - ] - else: - raise ValueError(f"Invalid loss strategy: {loss_strategy}") - - for edge in candidate_edges: - level_n_edges.append(edge) - edge[2]["visited"] = True - break - - max_extra_edges -= len(candidate_edges) - new_start_nodes = set() - - for edge in candidate_edges: - level_n_edges.append(edge) - edge[2]["visited"] = True - - if not edge[0] in start_nodes: - new_start_nodes.add(edge[0]) - if not edge[1] in start_nodes: - new_start_nodes.add(edge[1]) - - start_nodes = new_start_nodes - - return level_n_edges - - -def _get_level_n_edges_by_max_tokens( - edge_adj_list: dict, - node_dict: dict, - edges: list, - nodes: list, - src_edge: tuple, - max_depth: int, - bidirectional: bool, - max_tokens: int, - edge_sampling: str, - loss_strategy: str = "only_edge", -) -> list: - """ - Get level n edges for an edge. - n is decided by max_depth in traverse_strategy. 
- - :param edge_adj_list - :param node_dict - :param edges - :param nodes - :param src_edge - :param max_depth - :param bidirectional - :param max_tokens - :param edge_sampling - :return: level n edges - """ - src_id, tgt_id, src_edge_data = src_edge - - max_tokens -= ( - src_edge_data["length"] - + nodes[node_dict[src_id]][1]["length"] - + nodes[node_dict[tgt_id]][1]["length"] - ) - - level_n_edges = [] - - start_nodes = {tgt_id} if not bidirectional else {src_id, tgt_id} - temp_nodes = {src_id, tgt_id} - - while max_depth > 0 and max_tokens > 0: - max_depth -= 1 - - candidate_edges = [ - edges[edge_id] - for node in start_nodes - for edge_id in edge_adj_list[node] - if not edges[edge_id][2].get("visited", False) - ] - - if not candidate_edges: - break - - if loss_strategy == "both": - er_tuples = [ - ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) - for edge in candidate_edges - ] - candidate_edges = _sort_tuples(er_tuples, edge_sampling) - elif loss_strategy == "only_edge": - candidate_edges = _sort_edges(candidate_edges, edge_sampling) - else: - raise ValueError(f"Invalid loss strategy: {loss_strategy}") - - for edge in candidate_edges: - max_tokens -= edge[2]["length"] - if not edge[0] in temp_nodes: - max_tokens -= nodes[node_dict[edge[0]]][1]["length"] - if not edge[1] in temp_nodes: - max_tokens -= nodes[node_dict[edge[1]]][1]["length"] - - if max_tokens < 0: - return level_n_edges - - level_n_edges.append(edge) - edge[2]["visited"] = True - temp_nodes.add(edge[0]) - temp_nodes.add(edge[1]) - - new_start_nodes = set() - for edge in candidate_edges: - if not edge[0] in start_nodes: - new_start_nodes.add(edge[0]) - if not edge[1] in start_nodes: - new_start_nodes.add(edge[1]) - - start_nodes = new_start_nodes - - return level_n_edges - - -def _sort_tuples(er_tuples: list, edge_sampling: str) -> list: - """ - Sort edges with edge sampling strategy - - :param er_tuples: [(nodes:list, edge:tuple)] - :param edge_sampling: edge sampling strategy (random, min_loss, max_loss) - :return: sorted edges - """ - if edge_sampling == "random": - er_tuples = random.sample(er_tuples, len(er_tuples)) - elif edge_sampling == "min_loss": - er_tuples = sorted( - er_tuples, - key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"], - ) - elif edge_sampling == "max_loss": - er_tuples = sorted( - er_tuples, - key=lambda x: sum(node[1]["loss"] for node in x[0]) + x[1][2]["loss"], - reverse=True, - ) - else: - raise ValueError(f"Invalid edge sampling: {edge_sampling}") - edges = [edge for _, edge in er_tuples] - return edges - - -def _sort_edges(edges: list, edge_sampling: str) -> list: - """ - Sort edges with edge sampling strategy - - :param edges: total edges - :param edge_sampling: edge sampling strategy (random, min_loss, max_loss) - :return: sorted edges - """ - if edge_sampling == "random": - random.shuffle(edges) - elif edge_sampling == "min_loss": - edges = sorted(edges, key=lambda x: x[2]["loss"]) - elif edge_sampling == "max_loss": - edges = sorted(edges, key=lambda x: x[2]["loss"], reverse=True) - else: - raise ValueError(f"Invalid edge sampling: {edge_sampling}") - return edges - - -async def get_batches_with_strategy( # pylint: disable=too-many-branches - nodes: list, - edges: list, - graph_storage: NetworkXStorage, - traverse_strategy: Dict, -): - expand_method = traverse_strategy["expand_method"] - if expand_method == "max_width": - logger.info("Using max width strategy") - elif expand_method == "max_tokens": - logger.info("Using max tokens strategy") - else: - 
raise ValueError(f"Invalid expand method: {expand_method}") - - max_depth = traverse_strategy["max_depth"] - edge_sampling = traverse_strategy["edge_sampling"] - - # 构建临接矩阵 - edge_adj_list = defaultdict(list) - node_dict = {} - processing_batches = [] - - node_cache = {} - - async def get_cached_node_info(node_id: str) -> dict: - if node_id not in node_cache: - node_cache[node_id] = await _get_node_info(node_id, graph_storage) - return node_cache[node_id] - - for i, (node_name, _) in enumerate(nodes): - node_dict[node_name] = i - - if traverse_strategy["loss_strategy"] == "both": - er_tuples = [ - ([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) - for edge in edges - ] - edges = _sort_tuples(er_tuples, edge_sampling) - elif traverse_strategy["loss_strategy"] == "only_edge": - edges = _sort_edges(edges, edge_sampling) - else: - raise ValueError(f"Invalid loss strategy: {traverse_strategy['loss_strategy']}") - - for i, (src, tgt, _) in enumerate(edges): - edge_adj_list[src].append(i) - edge_adj_list[tgt].append(i) - - for edge in tqdm_async(edges, desc="Preparing batches"): - if "visited" in edge[2] and edge[2]["visited"]: - continue - - edge[2]["visited"] = True - - _process_nodes = [] - _process_edges = [] - - src_id = edge[0] - tgt_id = edge[1] - - _process_nodes.extend( - [await get_cached_node_info(src_id), await get_cached_node_info(tgt_id)] - ) - _process_edges.append(edge) - - if expand_method == "max_width": - level_n_edges = _get_level_n_edges_by_max_width( - edge_adj_list, - node_dict, - edges, - nodes, - edge, - max_depth, - traverse_strategy["bidirectional"], - traverse_strategy["max_extra_edges"], - edge_sampling, - traverse_strategy["loss_strategy"], - ) - else: - level_n_edges = _get_level_n_edges_by_max_tokens( - edge_adj_list, - node_dict, - edges, - nodes, - edge, - max_depth, - traverse_strategy["bidirectional"], - traverse_strategy["max_tokens"], - edge_sampling, - traverse_strategy["loss_strategy"], - ) - - for _edge in level_n_edges: - _process_nodes.append(await get_cached_node_info(_edge[0])) - _process_nodes.append(await get_cached_node_info(_edge[1])) - _process_edges.append(_edge) - - # 去重 - _process_nodes = list( - {node["node_id"]: node for node in _process_nodes}.values() - ) - _process_edges = list( - {(edge[0], edge[1]): edge for edge in _process_edges}.values() - ) - - processing_batches.append((_process_nodes, _process_edges)) - - logger.info("Processing batches: %d", len(processing_batches)) - - # isolate nodes - isolated_node_strategy = traverse_strategy["isolated_node_strategy"] - if isolated_node_strategy == "add": - processing_batches = await _add_isolated_nodes( - nodes, processing_batches, graph_storage - ) - logger.info( - "Processing batches after adding isolated nodes: %d", - len(processing_batches), - ) - - return processing_batches - - -async def _add_isolated_nodes( - nodes: list, - processing_batches: list, - graph_storage: NetworkXStorage, -) -> list: - visited_nodes = set() - for _process_nodes, _process_edges in processing_batches: - for node in _process_nodes: - visited_nodes.add(node["node_id"]) - for node in nodes: - if node[0] not in visited_nodes: - _process_nodes = [await _get_node_info(node[0], graph_storage)] - processing_batches.append((_process_nodes, [])) - return processing_batches diff --git a/graphgen/operators/generate/__init__.py b/graphgen/operators/generate/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..035eca3608197138d5c9f1ac1614792bd4a8754b 100644 --- a/graphgen/operators/generate/__init__.py 
+++ b/graphgen/operators/generate/__init__.py @@ -0,0 +1 @@ +from .generate_qas import generate_qas diff --git a/graphgen/operators/generate/generate_cot.py b/graphgen/operators/generate/generate_cot.py deleted file mode 100644 index e96635ac29c4d85b471bb9e4e7816c1c5252a146..0000000000000000000000000000000000000000 --- a/graphgen/operators/generate/generate_cot.py +++ /dev/null @@ -1,117 +0,0 @@ -import asyncio -from typing import Dict, List, Tuple - -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.models import CommunityDetector, NetworkXStorage, OpenAIClient -from graphgen.templates import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language - - -async def generate_cot( - graph_storage: NetworkXStorage, - synthesizer_llm_client: OpenAIClient, - method_params: Dict = None, -): - method = method_params.get("method", "leiden") - detector = CommunityDetector( - graph_storage=graph_storage, method=method, method_params=method_params - ) - - results = await detector.detect_communities() - - # Convert results to a format suitable for summarization - communities = {} - for node, community_id in results.items(): - if community_id not in communities: - communities[community_id] = [] - communities[community_id].append(node) - - if not communities: - return {} - - semaphore = asyncio.Semaphore(value=1000) - - async def _generate_from_single_community( - c_id: int, nodes: List[str] - ) -> Tuple[int, Tuple[str, str, str]]: - """Summarize a single community.""" - async with semaphore: - entities: List[str] = [] - relationships: List[str] = [] - - for n in nodes: - node_data = await graph_storage.get_node(n) - if node_data is not None: - entities.append(f"({n}: {node_data.get('description')})") - - edges = await graph_storage.get_node_edges(n) - for edge in edges: - target = edge[1] - if target in nodes: - edge_data = await graph_storage.get_edge(n, target) - relationships.append( - f"({n}) - [{edge_data['description']}] -> ({target})" - ) - - entities_str = "\n".join(entities) - relationships_str = "\n".join(relationships) - - language = ( - "English" - if detect_main_language(entities_str + relationships_str) == "en" - else "Chinese" - ) - - prompt = COT_TEMPLATE_DESIGN_PROMPT[language]["TEMPLATE"].format( - entities=entities_str, - relationships=relationships_str, - ) - - cot_template = await synthesizer_llm_client.generate_answer(prompt) - - if "问题:" in cot_template and "推理路径设计:" in cot_template: - question = cot_template.split("问题:")[1].split("推理路径设计:")[0].strip() - reasoning_path = cot_template.split("推理路径设计:")[1].strip() - elif ( - "Question:" in cot_template and "Reasoning-Path Design:" in cot_template - ): - question = ( - cot_template.split("Question:")[1] - .split("Reasoning-Path Design:")[0] - .strip() - ) - reasoning_path = cot_template.split("Reasoning-Path Design:")[1].strip() - else: - raise ValueError("COT template format is incorrect.") - - prompt = COT_GENERATION_PROMPT[language]["TEMPLATE"].format( - entities=entities_str, - relationships=relationships_str, - question=question, - reasoning_template=reasoning_path, - ) - - cot_answer = await synthesizer_llm_client.generate_answer(prompt) - - return c_id, (question, reasoning_path, cot_answer) - - cid_nodes = list(communities.items()) - - results: Dict = {} - async for coro in tqdm_async( - asyncio.as_completed( - [_generate_from_single_community(cid, nodes) for cid, nodes in cid_nodes] - ), - total=len(cid_nodes), - desc="[Generating COT] Generating CoT 
data from communities", - unit="community", - ): - cid, (q, r, a) = await coro - results[compute_content_hash(q)] = { - "question": q, - "reasoning_path": r, - "answer": a, - } - - return results diff --git a/graphgen/operators/generate/generate_qas.py b/graphgen/operators/generate/generate_qas.py new file mode 100644 index 0000000000000000000000000000000000000000..363bb7e4308f89109a9b5b34332dc4cc600bfd41 --- /dev/null +++ b/graphgen/operators/generate/generate_qas.py @@ -0,0 +1,58 @@ +from typing import Any + +from graphgen.bases import BaseLLMClient +from graphgen.models import ( + AggregatedGenerator, + AtomicGenerator, + CoTGenerator, + MultiHopGenerator, +) +from graphgen.utils import logger, run_concurrent + + +async def generate_qas( + llm_client: BaseLLMClient, + batches: list[ + tuple[ + list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] + ] + ], + generation_config: dict, +) -> list[dict[str, Any]]: + """ + Generate question-answer pairs based on nodes and edges. + :param llm_client: LLM client + :param batches + :param generation_config + :return: QA pairs + """ + mode = generation_config["mode"] + logger.info("[Generation] mode: %s, batches: %d", mode, len(batches)) + + if mode == "atomic": + generator = AtomicGenerator(llm_client) + elif mode == "aggregated": + generator = AggregatedGenerator(llm_client) + elif mode == "multi_hop": + generator = MultiHopGenerator(llm_client) + elif mode == "cot": + generator = CoTGenerator(llm_client) + else: + raise ValueError(f"Unsupported generation mode: {mode}") + + results = await run_concurrent( + generator.generate, + batches, + desc="[4/4]Generating QAs", + unit="batch", + ) + + # format + data_format = generation_config["data_format"] + logger.info("Output data format: %s", data_format) + + results = generator.format_generation_results( + results, output_data_format=data_format + ) + + return results diff --git a/graphgen/operators/partition/__init__.py b/graphgen/operators/partition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21f934b38b328b9fb804c5e6969bd85662c99a0e --- /dev/null +++ b/graphgen/operators/partition/__init__.py @@ -0,0 +1 @@ +from .partition_kg import partition_kg diff --git a/graphgen/operators/partition/partition_kg.py b/graphgen/operators/partition/partition_kg.py new file mode 100644 index 0000000000000000000000000000000000000000..b03d3221329397efabf1bdce3a41bf4c7d044ed5 --- /dev/null +++ b/graphgen/operators/partition/partition_kg.py @@ -0,0 +1,48 @@ +from typing import Any + +from graphgen.bases import BaseGraphStorage, BaseTokenizer +from graphgen.models import ( + BFSPartitioner, + DFSPartitioner, + ECEPartitioner, + LeidenPartitioner, +) +from graphgen.utils import logger + +from .pre_tokenize import pre_tokenize + + +async def partition_kg( + kg_instance: BaseGraphStorage, + tokenizer: Any = BaseTokenizer, + partition_config: dict = None, +) -> list[ + tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]] +]: + method = partition_config["method"] + method_params = partition_config["method_params"] + if method == "bfs": + logger.info("Partitioning knowledge graph using BFS method.") + partitioner = BFSPartitioner() + elif method == "dfs": + logger.info("Partitioning knowledge graph using DFS method.") + partitioner = DFSPartitioner() + elif method == "ece": + logger.info("Partitioning knowledge graph using ECE method.") + # TODO: before ECE partitioning, we need to: + # 1. 
'quiz and judge' to get the comprehension loss if unit_sampling is not random + # 2. pre-tokenize nodes and edges to get the token length + edges = await kg_instance.get_all_edges() + nodes = await kg_instance.get_all_nodes() + await pre_tokenize(kg_instance, tokenizer, edges, nodes) + partitioner = ECEPartitioner() + elif method == "leiden": + logger.info("Partitioning knowledge graph using Leiden method.") + partitioner = LeidenPartitioner() + else: + raise ValueError(f"Unsupported partition method: {method}") + + communities = await partitioner.partition(g=kg_instance, **method_params) + logger.info("Partitioned the graph into %d communities.", len(communities)) + batches = await partitioner.community2batch(communities, g=kg_instance) + return batches diff --git a/graphgen/operators/partition/pre_tokenize.py b/graphgen/operators/partition/pre_tokenize.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b45e39989dc02eed304323dd400bb4bbabde79 --- /dev/null +++ b/graphgen/operators/partition/pre_tokenize.py @@ -0,0 +1,47 @@ +import asyncio +from typing import List, Tuple + +from graphgen.bases import BaseGraphStorage, BaseTokenizer +from graphgen.utils import run_concurrent + + +async def pre_tokenize( + graph_storage: BaseGraphStorage, + tokenizer: BaseTokenizer, + edges: List[Tuple], + nodes: List[Tuple], +) -> Tuple[List, List]: + """Patch token lengths onto edges/nodes and write them back to storage (concurrency 1000, with progress bars).""" + sem = asyncio.Semaphore(1000) + + async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple: + async with sem: + data = obj[1] if is_node else obj[2] + if "length" not in data: + loop = asyncio.get_event_loop() + data["length"] = len( + await loop.run_in_executor( + None, tokenizer.encode, data["description"] + ) + ) + if is_node: + await graph_storage.update_node(obj[0], obj[1]) + else: + await graph_storage.update_edge(obj[0], obj[1], obj[2]) + return obj + + new_edges, new_nodes = await asyncio.gather( + run_concurrent( + lambda e: _patch_and_write(e, is_node=False), + edges, + desc="Pre-tokenizing edges", + ), + run_concurrent( + lambda n: _patch_and_write(n, is_node=True), + nodes, + desc="Pre-tokenizing nodes", + ), + ) + + await graph_storage.index_done_callback() + return new_edges, new_nodes diff --git a/graphgen/operators/search/__init__.py b/graphgen/operators/search/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..3d90f12a1e9e2e62c7b517b8246844bfbfe16bb3 100644 --- a/graphgen/operators/search/__init__.py +++ b/graphgen/operators/search/__init__.py @@ -0,0 +1 @@ +from .search_all import search_all diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py deleted file mode 100644 index dff63b0b32d664f745e195db79a03e44f2e0c50f..0000000000000000000000000000000000000000 --- a/graphgen/operators/traverse_graph.py +++ /dev/null @@ -1,540 +0,0 @@ -import asyncio -from typing import Dict - -import gradio as gr -from tqdm.asyncio import tqdm as tqdm_async - -from graphgen.models import JsonKVStorage, NetworkXStorage, OpenAIClient, Tokenizer -from graphgen.operators.build_kg.split_kg import get_batches_with_strategy -from graphgen.templates import ( - ANSWER_REPHRASING_PROMPT, - MULTI_HOP_GENERATION_PROMPT, - QUESTION_GENERATION_PROMPT, -) -from graphgen.utils import compute_content_hash, detect_main_language, logger - - -async def _pre_tokenize( - graph_storage: NetworkXStorage, tokenizer: Tokenizer, edges: list, nodes: list -) -> tuple: - - sem = asyncio.Semaphore(1000) - - async def handle_edge(edge: tuple) -> tuple: - async with
sem: - if "length" not in edge[2]: - edge[2]["length"] = len( - await asyncio.get_event_loop().run_in_executor( - None, tokenizer.encode, edge[2]["description"] - ) - ) - return edge - - async def handle_node(node: dict) -> dict: - async with sem: - if "length" not in node[1]: - node[1]["length"] = len( - await asyncio.get_event_loop().run_in_executor( - None, tokenizer.encode, node[1]["description"] - ) - ) - return node - - new_edges = [] - new_nodes = [] - - for result in tqdm_async( - asyncio.as_completed([handle_edge(edge) for edge in edges]), - total=len(edges), - desc="Pre-tokenizing edges", - ): - new_edge = await result - await graph_storage.update_edge(new_edge[0], new_edge[1], new_edge[2]) - new_edges.append(new_edge) - - for result in tqdm_async( - asyncio.as_completed([handle_node(node) for node in nodes]), - total=len(nodes), - desc="Pre-tokenizing nodes", - ): - new_node = await result - await graph_storage.update_node(new_node[0], new_node[1]) - new_nodes.append(new_node) - - await graph_storage.index_done_callback() - return new_edges, new_nodes - - -async def _construct_rephrasing_prompt( - _process_nodes: list, - _process_edges: list, - text_chunks_storage: JsonKVStorage, - add_context: bool = False, -) -> str: - entities = [ - f"{_process_node['node_id']}: {_process_node['description']}" - for _process_node in _process_nodes - ] - relations = [ - f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}" - for _process_edge in _process_edges - ] - - entities_str = "\n".join( - [f"{index + 1}. {entity}" for index, entity in enumerate(entities)] - ) - relations_str = "\n".join( - [f"{index + 1}. {relation}" for index, relation in enumerate(relations)] - ) - language = ( - "Chinese" - if detect_main_language(entities_str + relations_str) == "zh" - else "English" - ) - - if add_context: - original_ids = [ - node["source_id"].split("")[0] for node in _process_nodes - ] + [edge[2]["source_id"].split("")[0] for edge in _process_edges] - - original_ids = list(set(original_ids)) - original_text = await text_chunks_storage.get_by_ids(original_ids) - original_text = "\n".join( - [ - f"{index + 1}. 
{text['content']}" - for index, text in enumerate(original_text) - ] - ) - - prompt = ANSWER_REPHRASING_PROMPT[language]["CONTEXT_TEMPLATE"].format( - language=language, - original_text=original_text, - entities=entities_str, - relationships=relations_str, - ) - return prompt - - prompt = ANSWER_REPHRASING_PROMPT[language]["TEMPLATE"].format( - language=language, entities=entities_str, relationships=relations_str - ) - return prompt - - -def get_average_loss(batch: tuple, loss_strategy: str) -> float: - try: - if loss_strategy == "only_edge": - return sum(edge[2]["loss"] for edge in batch[1]) / len(batch[1]) - if loss_strategy == "both": - return sum(edge[2]["loss"] for edge in batch[1]) + sum( - node["loss"] for node in batch[0] - ) / (len(batch[0]) + len(batch[1])) - raise ValueError("Invalid loss strategy") - except Exception as e: # pylint: disable=broad-except - logger.warning( - "Loss not found in some nodes or edges, setting loss to -1.0: %s", e - ) - return -1.0 - - -def _post_process_synthetic_data(data): - block = data.split("\n\n") - qas = [] - for line in block: - if "Question:" in line and "Answer:" in line: - question = line.split("Question:")[1].split("Answer:")[0].strip() - answer = line.split("Answer:")[1].strip() - qas.append({"question": question, "answer": answer}) - elif "问题:" in line and "答案:" in line: - question = line.split("问题:")[1].split("答案:")[0].strip() - answer = line.split("答案:")[1].strip() - qas.append({"question": question, "answer": answer}) - elif "问题:" in line and "回答:" in line: - question = line.split("问题:")[1].split("回答:")[0].strip() - answer = line.split("回答:")[1].strip() - qas.append({"question": question, "answer": answer}) - return qas - - -async def traverse_graph_for_aggregated( - llm_client: OpenAIClient, - tokenizer: Tokenizer, - graph_storage: NetworkXStorage, - traverse_strategy: Dict, - text_chunks_storage: JsonKVStorage, - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> dict: - """ - Traverse the graph - - :param llm_client - :param tokenizer - :param graph_storage - :param traverse_strategy - :param text_chunks_storage - :param progress_bar - :param max_concurrent - :return: question and answer - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - async def _process_nodes_and_edges( - _process_nodes: list, - _process_edges: list, - ) -> str: - prompt = await _construct_rephrasing_prompt( - _process_nodes, _process_edges, text_chunks_storage, add_context=False - ) - context = await llm_client.generate_answer(prompt) - - # post-process the context - if context.startswith("Rephrased Text:"): - context = context[len("Rephrased Text:") :].strip() - elif context.startswith("重述文本:"): - context = context[len("重述文本:") :].strip() - - return context - - async def _process_single_batch( - _process_batch: tuple, question_type: str = "single" - ) -> dict: - async with semaphore: - context = await _process_nodes_and_edges( - _process_batch[0], - _process_batch[1], - ) - - language = "Chinese" if detect_main_language(context) == "zh" else "English" - pre_length = sum(node["length"] for node in _process_batch[0]) + sum( - edge[2]["length"] for edge in _process_batch[1] - ) - - if question_type == "single": - question = await llm_client.generate_answer( - QUESTION_GENERATION_PROMPT[language]["SINGLE_TEMPLATE"].format( - answer=context - ) - ) - if question.startswith("Question:"): - question = question[len("Question:") :].strip() - elif question.startswith("问题:"): - question = question[len("问题:") :].strip() - - logger.info( - "%d 
nodes and %d edges processed", - len(_process_batch[0]), - len(_process_batch[1]), - ) - logger.info("Pre-length: %s", pre_length) - logger.info("Question: %s", question) - logger.info("Answer: %s", context) - - return { - compute_content_hash(context): { - "question": question, - "answer": context, - "loss": get_average_loss( - _process_batch, traverse_strategy["loss_strategy"] - ), - } - } - - content = await llm_client.generate_answer( - QUESTION_GENERATION_PROMPT[language]["MULTI_TEMPLATE"].format( - doc=context - ) - ) - qas = _post_process_synthetic_data(content) - - if len(qas) == 0: - logger.error( - "Error occurred while processing batch, question or answer is None" - ) - return {} - - final_results = {} - logger.info( - "%d nodes and %d edges processed", - len(_process_batch[0]), - len(_process_batch[1]), - ) - logger.info("Pre-length: %s", pre_length) - for qa in qas: - logger.info("Question: %s", qa["question"]) - logger.info("Answer: %s", qa["answer"]) - final_results[compute_content_hash(qa["question"])] = { - "question": qa["question"], - "answer": qa["answer"], - "loss": get_average_loss( - _process_batch, traverse_strategy["loss_strategy"] - ), - } - return final_results - - results = {} - edges = list(await graph_storage.get_all_edges()) - nodes = list(await graph_storage.get_all_nodes()) - - edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes) - - processing_batches = await get_batches_with_strategy( - nodes, edges, graph_storage, traverse_strategy - ) - - for result in tqdm_async( - asyncio.as_completed( - [_process_single_batch(batch) for batch in processing_batches] - ), - total=len(processing_batches), - desc="[4/4]Generating QAs", - ): - try: - if progress_bar is not None: - progress_bar( - len(results) / len(processing_batches), desc="[4/4]Generating QAs" - ) - results.update(await result) - if progress_bar is not None and len(results) == len(processing_batches): - progress_bar(1, desc="[4/4]Generating QAs") - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating QA: %s", e) - - return results - - -# pylint: disable=too-many-branches, too-many-statements -async def traverse_graph_for_atomic( - llm_client: OpenAIClient, - tokenizer: Tokenizer, - graph_storage: NetworkXStorage, - traverse_strategy: Dict, - text_chunks_storage: JsonKVStorage, - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> dict: - """ - Traverse the graph atomicly - - :param llm_client - :param tokenizer - :param graph_storage - :param traverse_strategy - :param text_chunks_storage - :param progress_bar - :param max_concurrent - :return: question and answer - """ - - semaphore = asyncio.Semaphore(max_concurrent) - - def _parse_qa(qa: str) -> tuple: - if "Question:" in qa and "Answer:" in qa: - question = qa.split("Question:")[1].split("Answer:")[0].strip() - answer = qa.split("Answer:")[1].strip() - elif "问题:" in qa and "答案:" in qa: - question = qa.split("问题:")[1].split("答案:")[0].strip() - answer = qa.split("答案:")[1].strip() - else: - return None, None - return question.strip('"'), answer.strip('"') - - async def _generate_question(node_or_edge: tuple): - if len(node_or_edge) == 2: - des = node_or_edge[0] + ": " + node_or_edge[1]["description"] - loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0 - else: - des = node_or_edge[2]["description"] - loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0 - - async with semaphore: - try: - language = "Chinese" if 
detect_main_language(des) == "zh" else "English" - - qa = await llm_client.generate_answer( - QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format( - doc=des - ) - ) - - question, answer = _parse_qa(qa) - if question is None or answer is None: - return {} - - question = question.strip('"') - answer = answer.strip('"') - - logger.info("Question: %s", question) - logger.info("Answer: %s", answer) - return { - compute_content_hash(question): { - "question": question, - "answer": answer, - "loss": loss, - } - } - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating question: %s", e) - return {} - - results = {} - edges = list(await graph_storage.get_all_edges()) - nodes = list(await graph_storage.get_all_nodes()) - - edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes) - - tasks = [] - for node in nodes: - if "" in node[1]["description"]: - description_list = node[1]["description"].split("") - for item in description_list: - tasks.append((node[0], {"description": item})) - if "loss" in node[1]: - tasks[-1][1]["loss"] = node[1]["loss"] - else: - tasks.append((node[0], node[1])) - for edge in edges: - if "" in edge[2]["description"]: - description_list = edge[2]["description"].split("") - for item in description_list: - tasks.append((edge[0], edge[1], {"description": item})) - if "loss" in edge[2]: - tasks[-1][2]["loss"] = edge[2]["loss"] - else: - tasks.append((edge[0], edge[1], edge[2])) - - for result in tqdm_async( - asyncio.as_completed([_generate_question(task) for task in tasks]), - total=len(tasks), - desc="[4/4]Generating QAs", - ): - try: - if progress_bar is not None: - progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs") - results.update(await result) - if progress_bar is not None and len(results) == len(tasks): - progress_bar(1, desc="[4/4]Generating QAs") - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating QA: %s", e) - return results - - -async def traverse_graph_for_multi_hop( - llm_client: OpenAIClient, - tokenizer: Tokenizer, - graph_storage: NetworkXStorage, - traverse_strategy: Dict, - text_chunks_storage: JsonKVStorage, - progress_bar: gr.Progress = None, - max_concurrent: int = 1000, -) -> dict: - """ - Traverse the graph for multi-hop - - :param llm_client - :param tokenizer - :param graph_storage - :param traverse_strategy - :param text_chunks_storage - :param progress_bar - :param max_concurrent - :return: question and answer - """ - semaphore = asyncio.Semaphore(max_concurrent) - - results = {} - edges = list(await graph_storage.get_all_edges()) - nodes = list(await graph_storage.get_all_nodes()) - - edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes) - - processing_batches = await get_batches_with_strategy( - nodes, edges, graph_storage, traverse_strategy - ) - - async def _process_single_batch(_process_batch: tuple) -> dict: - async with semaphore: - try: - language = ( - "Chinese" - if detect_main_language(_process_batch[0][0]["description"]) == "zh" - else "English" - ) - - _process_nodes = _process_batch[0] - _process_edges = _process_batch[1] - - entities = [ - f"{_process_node['node_id']}: {_process_node['description']}" - for _process_node in _process_nodes - ] - - relations = [ - f"{_process_edge[0]} -- {_process_edge[1]}: {_process_edge[2]['description']}" - for _process_edge in _process_edges - ] - - entities_str = "\n".join( - [f"{index + 1}. 
{entity}" for index, entity in enumerate(entities)] - ) - relations_str = "\n".join( - [ - f"{index + 1}. {relation}" - for index, relation in enumerate(relations) - ] - ) - - prompt = MULTI_HOP_GENERATION_PROMPT[language].format( - entities=entities_str, relationships=relations_str - ) - - context = await llm_client.generate_answer(prompt) - - # post-process the context - if "Question:" in context and "Answer:" in context: - question = context.split("Question:")[1].split("Answer:")[0].strip() - answer = context.split("Answer:")[1].strip() - elif "问题:" in context and "答案:" in context: - question = context.split("问题:")[1].split("答案:")[0].strip() - answer = context.split("答案:")[1].strip() - else: - return {} - - question = question.strip('"') - answer = answer.strip('"') - - logger.info("Question: %s", question) - logger.info("Answer: %s", answer) - - return { - compute_content_hash(question): { - "question": question, - "answer": answer, - "loss": get_average_loss( - _process_batch, traverse_strategy["loss_strategy"] - ), - } - } - - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while processing batch: %s", e) - return {} - - async for result in tqdm_async( - asyncio.as_completed( - [_process_single_batch(batch) for batch in processing_batches] - ), - total=len(processing_batches), - desc="[4/4]Generating QAs", - ): - try: - if progress_bar is not None: - progress_bar( - len(results) / len(processing_batches), desc="[4/4]Generating QAs" - ) - results.update(await result) - if progress_bar is not None and len(results) == len(processing_batches): - progress_bar(1, desc="[4/4]Generating QAs") - except Exception as e: # pylint: disable=broad-except - logger.error("Error occurred while generating QA: %s", e) - return results diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py index a3d1e9ed5dfd20f0f08cb6c39f40bb1794b80ca4..8f764cc0d8432f94cd1439e2323e9d7ddcb4c932 100644 --- a/graphgen/templates/__init__.py +++ b/graphgen/templates/__init__.py @@ -1,10 +1,13 @@ -from .answer_rephrasing import ANSWER_REPHRASING_PROMPT -from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT +from .generation import ( + AGGREGATED_GENERATION_PROMPT, + ATOMIC_GENERATION_PROMPT, + COT_GENERATION_PROMPT, + MULTI_HOP_GENERATION_PROMPT, +) from .kg_extraction import KG_EXTRACTION_PROMPT from .kg_summarization import KG_SUMMARIZATION_PROMPT -from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT from .question_generation import QUESTION_GENERATION_PROMPT from .search_judgement import SEARCH_JUDGEMENT_PROMPT from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT diff --git a/graphgen/templates/community/__init__.py b/graphgen/templates/community/__init__.py deleted file mode 100644 index 4721d03e285fe2ef43778e808cf2481c86ddb78e..0000000000000000000000000000000000000000 --- a/graphgen/templates/community/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .cot_generation import COT_GENERATION_PROMPT -from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT diff --git a/graphgen/templates/community/cot_generation.py b/graphgen/templates/community/cot_generation.py deleted file mode 100644 index 0494cd805d605d5ebdf5dd24b06bd90111c2a292..0000000000000000000000000000000000000000 --- a/graphgen/templates/community/cot_generation.py +++ /dev/null @@ -1,87 +0,0 @@ -TEMPLATE_ZH = 
"""根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\ -CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。 - --输入格式- -[Entities:] -(实体名:实体描述) -... - -[Relationships:] -(来源实体)-[关系描述]->(目标实体) -... - -[Question and Reasoning Path:] -(问题) -(推理路径) - --输出要求- -1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。 -2. 使用中文。 -3. 不要使用有序列表或编号。 -4. 请直接给出答案,不要生成无关信息。 - --真实数据- -输入: -[Entities:]: -{entities} - -[Relationships:]: -{relationships} - -[Question:]: -{question} - -[Reasoning_Template:]: -{reasoning_template} - -输出: - -""" - -TEMPLATE_EN = """Given the raw knowledge graph information and the provided reasoning-path, \ -produce one Chain-of-Thought (CoT) sample that strictly follows the template \ -and can be directly used for downstream training or inference. -CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \ -explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \ -only the final answer. - --Input Format- -[Entities:]: -(ENTITY_NAME: ENTITY_DESCRIPTION) -... - -[Relationships:]: -(ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET) -... - -[Question and Reasoning Path:]: -(QUESTION) -(REASONING_PATH) - --Output Requirements- -1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words. -2. Use English. -3. Do not use ordered lists or numbering. -4. Do not generate extraneous information, just provide the answer. - --Real Data- -Input: -[Entities:]: -{entities} - -[Relationships:]: -{relationships} - -[Question:]: -{question} - -[Reasoning_Template:]: -{reasoning_template} - -Output: -""" - -COT_GENERATION_PROMPT = { - "Chinese": {"TEMPLATE": TEMPLATE_ZH}, - "English": {"TEMPLATE": TEMPLATE_EN}, -} diff --git a/graphgen/templates/generation/__init__.py b/graphgen/templates/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b624d3bc8ed191b55a791e67beb6f2d63316afa --- /dev/null +++ b/graphgen/templates/generation/__init__.py @@ -0,0 +1,4 @@ +from .aggregated_generation import AGGREGATED_GENERATION_PROMPT +from .atomic_generation import ATOMIC_GENERATION_PROMPT +from .cot_generation import COT_GENERATION_PROMPT +from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT diff --git a/graphgen/templates/answer_rephrasing.py b/graphgen/templates/generation/aggregated_generation.py similarity index 85% rename from graphgen/templates/answer_rephrasing.py rename to graphgen/templates/generation/aggregated_generation.py index fc988fa25edeedca98674268c9403c50f2ebb995..9e1bfac8376e3b6bf6fb2bd6bde30c5da4c3c81d 100644 --- a/graphgen/templates/answer_rephrasing.py +++ b/graphgen/templates/generation/aggregated_generation.py @@ -1,4 +1,5 @@ -TEMPLATE_CONTEXT_EN: str = """---Role--- +# pylint: disable=C0301 +ANSWER_REPHRASING_CONTEXT_EN: str = """---Role--- You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below. You may refer to the original text to assist in generating the rephrased version, but ensure that the final output text meets the requirements. Use {language} as output language. 
@@ -49,7 +50,7 @@ To generate a version of the text that is rephrased and conveys the same meaning """ -TEMPLATE_CONTEXT_ZH: str = """---角色--- +ANSWER_REPHRASING_CONTEXT_ZH: str = """---角色--- 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。你可以参考原始文本辅助生成,但需要确保最终输出的文本符合要求。 使用{language}作为输出语言。 @@ -97,7 +98,7 @@ TEMPLATE_CONTEXT_ZH: str = """---角色--- """ -TEMPLATE_EN: str = """---Role--- +ANSWER_REPHRASING_EN: str = """---Role--- You are an NLP expert responsible for generating a logically structured and coherent rephrased version of the TEXT based on ENTITIES and RELATIONSHIPS provided below. Use {language} as output language. @@ -143,7 +144,7 @@ To generate a version of the text that is rephrased and conveys the same meaning """ -TEMPLATE_ZH: str = """---角色--- +ANSWER_REPHRASING_ZH: str = """---角色--- 你是一位NLP专家,负责根据下面提供的实体和关系生成逻辑结构清晰且连贯的文本重述版本。 使用{language}作为输出语言。 @@ -200,14 +201,33 @@ Please directly output the coherent rephrased text below, without any additional Rephrased Text: """ +QUESTION_GENERATION_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer. -ANSWER_REPHRASING_PROMPT = { - "English": { - "TEMPLATE": TEMPLATE_EN + REQUIREMENT_EN, - "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_EN + REQUIREMENT_EN, +################ +Answer: +{answer} +################ +Question: +""" + +QUESTION_GENERATION_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。 + +################ +答案: +{answer} +################ +问题: +""" + +AGGREGATED_GENERATION_PROMPT = { + "en": { + "ANSWER_REPHRASING": ANSWER_REPHRASING_EN + REQUIREMENT_EN, + "ANSWER_REPHRASING_CONTEXT": ANSWER_REPHRASING_CONTEXT_EN + REQUIREMENT_EN, + "QUESTION_GENERATION": QUESTION_GENERATION_EN, }, - "Chinese": { - "TEMPLATE": TEMPLATE_ZH + REQUIREMENT_ZH, - "CONTEXT_TEMPLATE": TEMPLATE_CONTEXT_ZH + REQUIREMENT_ZH, + "zh": { + "ANSWER_REPHRASING": ANSWER_REPHRASING_ZH + REQUIREMENT_ZH, + "ANSWER_REPHRASING_CONTEXT": ANSWER_REPHRASING_CONTEXT_ZH + REQUIREMENT_ZH, + "QUESTION_GENERATION": QUESTION_GENERATION_ZH, }, } diff --git a/graphgen/templates/generation/atomic_generation.py b/graphgen/templates/generation/atomic_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..499100f79d954004c22c75cc3ec88bc899b604f2 --- /dev/null +++ b/graphgen/templates/generation/atomic_generation.py @@ -0,0 +1,32 @@ +# pylint: disable=C0301 +TEMPLATE_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. +The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. +For example: +Question: What is the effect of overexpressing the BG1 gene on grain size and development? +Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. + +Question: What role does TAC4 play in the gravitropism of rice shoots? +Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. + +Here is the text passage you need to generate a QA pair for: +{context} +""" + +TEMPLATE_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 +答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 +例如: +问题:过表达BG1基因对谷粒大小和发育有什么影响? +答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 + +问题:TAC4在水稻茎的重力性状中扮演什么角色? 
+答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 + +以下是你需要为其生成QA对的文本段落: +{context} +""" + + +ATOMIC_GENERATION_PROMPT = { + "en": TEMPLATE_EN, + "zh": TEMPLATE_ZH, +} diff --git a/graphgen/templates/community/cot_template_design.py b/graphgen/templates/generation/cot_generation.py similarity index 61% rename from graphgen/templates/community/cot_template_design.py rename to graphgen/templates/generation/cot_generation.py index 04cfa2309c7f035124477084734a38c1b9a6a5d1..849a7c71b043a61e206bf7442f5055aa8633ecf5 100644 --- a/graphgen/templates/community/cot_template_design.py +++ b/graphgen/templates/generation/cot_generation.py @@ -1,4 +1,87 @@ -TEMPLATE_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\ +COT_GENERATION_ZH = """根据给定的知识图谱原始信息及已生成的推理路径,产出一条符合模板要求、可直接用于下游训练或推理的 CoT 数据。\ +CoT(Chain-of-Thought,思维链)指在回答复杂问题时,把中间推理步骤一步一步显式写出来,使推理过程透明、可追溯,而不是直接给出最终答案。 + +-输入格式- +[Entities:] +(实体名:实体描述) +... + +[Relationships:] +(来源实体)-[关系描述]->(目标实体) +... + +[Question and Reasoning Path:] +(问题) +(推理路径) + +-输出要求- +1. 每一步只完成一个不可分割的子任务,并用自然语言衔接,但是要避免生硬的连接词。 +2. 使用中文。 +3. 不要使用有序列表或编号。 +4. 请直接给出答案,不要生成无关信息。 + +-真实数据- +输入: +[Entities:]: +{entities} + +[Relationships:]: +{relationships} + +[Question:]: +{question} + +[Reasoning_Template:]: +{reasoning_template} + +输出: + +""" + +COT_GENERATION_EN = """Given the raw knowledge graph information and the provided reasoning-path, \ +produce one Chain-of-Thought (CoT) sample that strictly follows the template \ +and can be directly used for downstream training or inference. +CoT (Chain-of-Thought) means that when answering a complex question, the intermediate reasoning steps are \ +explicitly written out one by one, making the reasoning process transparent and traceable instead of giving \ +only the final answer. + +-Input Format- +[Entities:]: +(ENTITY_NAME: ENTITY_DESCRIPTION) +... + +[Relationships:]: +(ENTITY_SOURCE)-[RELATIONSHIP_DESCRIPTION]->(ENTITY_TARGET) +... + +[Question and Reasoning Path:]: +(QUESTION) +(REASONING_PATH) + +-Output Requirements- +1. Each step completes a single, indivisible sub-task and is naturally connected, avoiding abrupt transition words. +2. Use English. +3. Do not use ordered lists or numbering. +4. Do not generate extraneous information, just provide the answer. + +-Real Data- +Input: +[Entities:]: +{entities} + +[Relationships:]: +{relationships} + +[Question:]: +{question} + +[Reasoning_Template:]: +{reasoning_template} + +Output: +""" + +COT_TEMPLATE_DESIGN_ZH = """你是一位“元推理架构师”。你的任务不是回答问题,\ 而是根据给定的知识图谱中的实体和关系的名称以及描述信息,设计一条可复用、可泛化的 CoT 推理路径模板。\ -步骤- @@ -47,7 +130,7 @@ TEMPLATE_ZH = """你是一位“元推理架构师”。你的任务不是回答 """ -TEMPLATE_EN = """You are a “meta-reasoning architect”. \ +COT_TEMPLATE_DESIGN_EN = """You are a “meta-reasoning architect”. \ Your task is NOT to answer the question, but to design a reusable, generalizable CoT reasoning-path \ template based solely on the names and descriptions of entities and \ relationships in the provided knowledge graph. 
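ATOMIC_GENERATION_PROMPT in atomic_generation.py above is keyed by language and formatted with a single {context} passage. A hedged sketch of how a generator might build such a prompt for one entity description; the actual AtomicGenerator.build_prompt is not shown in this patch, so the helper below is only an assumed shape:

from graphgen.templates import ATOMIC_GENERATION_PROMPT
from graphgen.utils import detect_main_language


def build_atomic_prompt(node_name: str, node_data: dict) -> str:
    # One QA prompt per entity description, mirroring the per-node tasks that the
    # deleted traverse_graph_for_atomic used to build.
    description = f"{node_name}: {node_data['description']}"
    lang = "zh" if detect_main_language(description) == "zh" else "en"
    return ATOMIC_GENERATION_PROMPT[lang].format(context=description)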
@@ -101,7 +184,13 @@ Input: Output: """ -COT_TEMPLATE_DESIGN_PROMPT = { - "Chinese": {"TEMPLATE": TEMPLATE_ZH}, - "English": {"TEMPLATE": TEMPLATE_EN}, +COT_GENERATION_PROMPT = { + "en": { + "COT_GENERATION": COT_GENERATION_EN, + "COT_TEMPLATE_DESIGN": COT_TEMPLATE_DESIGN_EN, + }, + "zh": { + "COT_GENERATION": COT_GENERATION_ZH, + "COT_TEMPLATE_DESIGN": COT_TEMPLATE_DESIGN_ZH, + }, } diff --git a/graphgen/templates/multi_hop_generation.py b/graphgen/templates/generation/multi_hop_generation.py similarity index 95% rename from graphgen/templates/multi_hop_generation.py rename to graphgen/templates/generation/multi_hop_generation.py index dad2ee36204f8eae483a99d39d55c2b04ba879b9..73857ebb7245af5de6263edcb647ec468c88371d 100644 --- a/graphgen/templates/multi_hop_generation.py +++ b/graphgen/templates/generation/multi_hop_generation.py @@ -1,5 +1,4 @@ # pylint: disable=C0301 - TEMPLATE_ZH: str = """请基于以下知识子图生成多跳推理问题和答案。你将获得一个知识子图,其中包含一系列实体、关系和事实。你的任务是提出一个问题,该问题需要经过多次推理才能回答。问题的答案应该是从给定的知识子图中推断出来的。确保问题的难度适中,需要多步推理才能回答。 例如: @@ -54,7 +53,4 @@ Answer: Vitamin C Output the generated question and answer directly, please do not copy the example question and answer directly, and do not provide irrelevant information. """ -MULTI_HOP_GENERATION_PROMPT = { - "English": TEMPLATE_EN, - "Chinese": TEMPLATE_ZH -} +MULTI_HOP_GENERATION_PROMPT = {"en": TEMPLATE_EN, "zh": TEMPLATE_ZH} diff --git a/graphgen/templates/question_generation.py b/graphgen/templates/question_generation.py index d9ca9128fad65c127f985080f887351cf2efe68e..e75bf16975d05e3d49d046c6f52b0402ed23062c 100644 --- a/graphgen/templates/question_generation.py +++ b/graphgen/templates/question_generation.py @@ -1,47 +1,5 @@ # pylint: disable=C0301 -TEMPLATE_SINGLE_EN: str = """The answer to a question is provided. Please generate a question that corresponds to the answer. -################ -Answer: -{answer} -################ -Question: -""" - -TEMPLATE_SINGLE_ZH: str = """下面提供了一个问题的答案,请生成一个与答案对应的问题。 - -################ -答案: -{answer} -################ -问题: -""" - -TEMPLATE_SINGLE_QA_EN: str = """You are given a text passage. Your task is to generate a question and answer (QA) pair based on the content of that text. -The answer should be accurate and directly derived from the text. Make sure the QA pair is relevant to the main theme or important details of the given text. -For example: -Question: What is the effect of overexpressing the BG1 gene on grain size and development? -Answer: Overexpression of the BG1 gene leads to significantly increased grain size, demonstrating its role in grain development. - -Question: What role does TAC4 play in the gravitropism of rice shoots? -Answer: TAC4 is a key regulator of gravitropism in rice shoots, promoting the bending of shoots towards the gravity vector. - -Here is the text passage you need to generate a QA pair for: -{doc} -""" - -TEMPLATE_SINGLE_QA_ZH: str = """给定一个文本段落。你的任务是根据该文本的内容生成一个问答(QA)对。 -答案应准确且直接从文本中得出。确保QA对与给定文本的主题或重要细节相关。 -例如: -问题:过表达BG1基因对谷粒大小和发育有什么影响? -答案:BG1基因的过表达显著增加了谷粒大小,表明其在谷物发育中的作用。 - -问题:TAC4在水稻茎的重力性状中扮演什么角色? -答案:TAC4是水稻茎重力性状的关键调节因子,促进茎向重力矢量弯曲。 - -以下是你需要为其生成QA对的文本段落: -{doc} -""" # TODO: 修改这里的prompt TEMPLATE_MULTI_EN = """You are an assistant to help read a article and then rephrase it in a question answering format. The user will provide you with an article with its content. You need to generate a paraphrase of the same article in question and answer format with one tag of "Question: ..." followed by "Answer: ...". Remember to keep the meaning and every content of the article intact. 
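COT_GENERATION_PROMPT now bundles the template-design and generation prompts under one language key, so the two LLM calls that the deleted generate_cot.py made can be chained through a single dict. A sketch of that two-step flow for an English community, assuming a client exposing generate_answer as used elsewhere in this patch (the helper itself is illustrative, not part of the diff):

from graphgen.templates import COT_GENERATION_PROMPT


async def generate_cot_sample(llm_client, entities_str: str, relationships_str: str):
    # Step 1: ask for a reusable question + reasoning-path template.
    design_prompt = COT_GENERATION_PROMPT["en"]["COT_TEMPLATE_DESIGN"].format(
        entities=entities_str, relationships=relationships_str
    )
    template = await llm_client.generate_answer(design_prompt)
    # Same English parsing the deleted generate_cot.py applied to the template.
    question = template.split("Question:")[1].split("Reasoning-Path Design:")[0].strip()
    reasoning_path = template.split("Reasoning-Path Design:")[1].strip()
    # Step 2: expand the template into a full chain-of-thought answer.
    answer_prompt = COT_GENERATION_PROMPT["en"]["COT_GENERATION"].format(
        entities=entities_str,
        relationships=relationships_str,
        question=question,
        reasoning_template=reasoning_path,
    )
    answer = await llm_client.generate_answer(answer_prompt)
    return question, reasoning_path, answer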
@@ -66,13 +24,9 @@ TEMPLATE_MULTI_ZH = """你是一位助手,帮助阅读一篇文章,然后以 QUESTION_GENERATION_PROMPT = { "English": { - "SINGLE_TEMPLATE": TEMPLATE_SINGLE_EN, - "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_EN, - "MULTI_TEMPLATE": TEMPLATE_MULTI_EN + "MULTI_TEMPLATE": TEMPLATE_MULTI_EN, }, "Chinese": { - "SINGLE_TEMPLATE": TEMPLATE_SINGLE_ZH, - "SINGLE_QA_TEMPLATE": TEMPLATE_SINGLE_QA_ZH, - "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH - } + "MULTI_TEMPLATE": TEMPLATE_MULTI_ZH, + }, } diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index d56ca734463dfda9d796a216dd314f831e0039e9..3d80d2dfe4c91a798b73b8c74f36e9a19b9fe419 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -1,7 +1,6 @@ from .calculate_confidence import yes_no_loss_entropy from .detect_lang import detect_if_chinese, detect_main_language from .format import ( - format_generation_results, handle_single_entity_extraction, handle_single_relationship_extraction, load_json, diff --git a/graphgen/utils/format.py b/graphgen/utils/format.py index abc34c874a5b413a478e513d9f5109241f36c8a8..1f0675f13b120af7d2946d41fe9d72b0fac8ee8d 100644 --- a/graphgen/utils/format.py +++ b/graphgen/utils/format.py @@ -4,8 +4,6 @@ import os import re from typing import Any -from .log import logger - def pack_history_conversations(*args: str): roles = ["user", "assistant"] @@ -92,43 +90,3 @@ def write_json(json_obj, file_name): os.makedirs(os.path.dirname(file_name), exist_ok=True) with open(file_name, "w", encoding="utf-8") as f: json.dump(json_obj, f, indent=4, ensure_ascii=False) - - -def format_generation_results( - results: dict[str, Any], output_data_format: str -) -> list[dict[str, Any]]: - if output_data_format == "Alpaca": - logger.info("Output data format: Alpaca") - results = [ - { - "instruction": item["question"], - "input": "", - "output": item["answer"], - } - for item in list(results.values()) - ] - elif output_data_format == "Sharegpt": - logger.info("Output data format: Sharegpt") - results = [ - { - "conversations": [ - {"from": "human", "value": item["question"]}, - {"from": "gpt", "value": item["answer"]}, - ] - } - for item in list(results.values()) - ] - elif output_data_format == "ChatML": - logger.info("Output data format: ChatML") - results = [ - { - "messages": [ - {"role": "user", "content": item["question"]}, - {"role": "assistant", "content": item["answer"]}, - ] - } - for item in list(results.values()) - ] - else: - raise ValueError(f"Unknown output data format: {output_data_format}") - return results diff --git a/webui/app.py b/webui/app.py index 1179a7d0fbd5e8ad72b12ea9585a6690e348e978..7af93c3d970e0fe6c1f62a9a3f3cf234d5562e73 100644 --- a/webui/app.py +++ b/webui/app.py @@ -468,7 +468,7 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: label="TPM", minimum=5000, maximum=5000000, - value=50000, + value=100000, step=1000, interactive=True, visible=True, diff --git a/webui/translation.json b/webui/translation.json index c51659b0f19303f53e8445741eed1736e3932a33..30cf2f11e00b38766a00a53ab946729c5b37d595 100644 --- a/webui/translation.json +++ b/webui/translation.json @@ -37,7 +37,7 @@ "Generation Config": "生成配置", "API Config": "API Config", "### ": "### ", - "SiliconFlow Token": "SiliconFlow Token", + "SiliconFlow Token": "硅基流动 API 秘钥", "Upload File": "上传文件", "Example Files": "示例文件", "Output File": "输出文件",