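"""
GraphGen: build a knowledge graph from input documents and synthesize QA
data from it.

The pipeline runs roughly as read -> chunk -> build_kg, with optional
search and quiz_and_judge steps, followed by partition and extract /
generate. A usage sketch is at the end of this file.
"""
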
import os
import time
from typing import Dict, Optional
import gradio as gr
from graphgen.bases import BaseLLMWrapper
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
JsonKVStorage,
JsonListStorage,
NetworkXStorage,
OpenAIClient,
Tokenizer,
)
from graphgen.operators import (
build_kg,
chunk_documents,
extract_info,
generate_qas,
init_llm,
judge_statement,
partition_kg,
quiz,
read_files,
search_all,
)
from graphgen.utils import async_to_sync_method, compute_mm_hash, logger

sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))


class GraphGen:
    def __init__(
        self,
        unique_id: Optional[int] = None,
        working_dir: str = os.path.join(sys_path, "cache"),
        tokenizer_instance: Tokenizer = None,
        synthesizer_llm_client: OpenAIClient = None,
        trainee_llm_client: OpenAIClient = None,
        progress_bar: gr.Progress = None,
    ):
        # Resolve the run id at call time: a default of int(time.time()) in the
        # signature would be evaluated once, at class definition, and shared by
        # every instance.
        self.unique_id: int = unique_id if unique_id is not None else int(time.time())
        self.working_dir: str = working_dir
# llm
self.tokenizer_instance: Tokenizer = tokenizer_instance or Tokenizer(
model_name=os.getenv("TOKENIZER_MODEL", "cl100k_base")
)
self.synthesizer_llm_client: BaseLLMWrapper = (
synthesizer_llm_client or init_llm("synthesizer")
)
self.trainee_llm_client: BaseLLMWrapper = trainee_llm_client
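        # storage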
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="full_docs"
)
self.chunks_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="chunks"
)
self.graph_storage: NetworkXStorage = NetworkXStorage(
self.working_dir, namespace="graph"
)
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
self.working_dir, namespace="rephrase"
)
self.partition_storage: JsonListStorage = JsonListStorage(
self.working_dir, namespace="partition"
)
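        # per-run outputs go under <working_dir>/data/graphgen/<unique_id>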
self.search_storage: JsonKVStorage = JsonKVStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="search",
)
self.qa_storage: JsonListStorage = JsonListStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="qa",
)
self.extract_storage: JsonKVStorage = JsonKVStorage(
os.path.join(self.working_dir, "data", "graphgen", f"{self.unique_id}"),
namespace="extraction",
)
# webui
self.progress_bar: gr.Progress = progress_bar

    @async_to_sync_method
async def read(self, read_config: Dict):
"""
read files from input sources
"""
doc_stream = read_files(**read_config, cache_dir=self.working_dir)
batch = {}
for doc in doc_stream:
doc_id = compute_mm_hash(doc, prefix="doc-")
batch[doc_id] = doc
# TODO: configurable whether to use coreference resolution
_add_doc_keys = self.full_docs_storage.filter_keys(list(batch.keys()))
new_docs = {k: v for k, v in batch.items() if k in _add_doc_keys}
if len(new_docs) == 0:
logger.warning("All documents are already in the storage")
return
self.full_docs_storage.upsert(new_docs)
self.full_docs_storage.index_done_callback()

    @async_to_sync_method
async def chunk(self, chunk_config: Dict):
"""
chunk documents into smaller pieces from full_docs_storage if not already present
"""
new_docs = self.full_docs_storage.get_all()
if len(new_docs) == 0:
logger.warning("All documents are already in the storage")
return
inserting_chunks = await chunk_documents(
new_docs,
self.tokenizer_instance,
self.progress_bar,
**chunk_config,
)
_add_chunk_keys = self.chunks_storage.filter_keys(list(inserting_chunks.keys()))
inserting_chunks = {
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
}
if len(inserting_chunks) == 0:
logger.warning("All chunks are already in the storage")
return
self.chunks_storage.upsert(inserting_chunks)
self.chunks_storage.index_done_callback()

    @async_to_sync_method
async def build_kg(self):
"""
build knowledge graph from text chunks
"""
# Step 1: get new chunks
inserting_chunks = self.chunks_storage.get_all()
        if len(inserting_chunks) == 0:
            logger.warning("No chunks found to build the knowledge graph")
            return
logger.info("[New Chunks] inserting %d chunks", len(inserting_chunks))
# Step 2: build knowledge graph from new chunks
_add_entities_and_relations = await build_kg(
llm_client=self.synthesizer_llm_client,
kg_instance=self.graph_storage,
chunks=[Chunk.from_dict(k, v) for k, v in inserting_chunks.items()],
progress_bar=self.progress_bar,
)
if not _add_entities_and_relations:
logger.warning("No entities or relations extracted from text chunks")
return
        # Step 3: persist the updated graph to storage
        self.graph_storage.index_done_callback()
return _add_entities_and_relations

    @async_to_sync_method
async def search(self, search_config: Dict):
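        """
        Search the configured data sources, using stored documents as seeds,
        and persist any new results to search_storage.
        """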
logger.info("[Search] %s ...", ", ".join(search_config["data_sources"]))
seeds = self.full_docs_storage.get_all()
        if len(seeds) == 0:
            logger.warning("No seed documents found to search")
            return
search_results = await search_all(
seed_data=seeds,
search_config=search_config,
)
_add_search_keys = self.search_storage.filter_keys(list(search_results.keys()))
search_results = {
k: v for k, v in search_results.items() if k in _add_search_keys
}
if len(search_results) == 0:
logger.warning("All search results are already in the storage")
return
self.search_storage.upsert(search_results)
self.search_storage.index_done_callback()

    @async_to_sync_method
async def quiz_and_judge(self, quiz_and_judge_config: Dict):
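        """
        Generate quiz statements from the graph with the synthesizer LLM, then
        have the trainee LLM judge them to update relations in the graph.
        """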
        logger.warning(
            "The quiz-and-judge step requires a trainee LLM client;"
            " make sure one is available."
        )
max_samples = quiz_and_judge_config["quiz_samples"]
await quiz(
self.synthesizer_llm_client,
self.graph_storage,
self.rephrase_storage,
max_samples,
progress_bar=self.progress_bar,
)
# TODO: assert trainee_llm_client is valid before judge
if not self.trainee_llm_client:
# TODO: shutdown existing synthesizer_llm_client properly
logger.info("No trainee LLM client provided, initializing a new one.")
self.synthesizer_llm_client.shutdown()
self.trainee_llm_client = init_llm("trainee")
re_judge = quiz_and_judge_config["re_judge"]
_update_relations = await judge_statement(
self.trainee_llm_client,
self.graph_storage,
self.rephrase_storage,
re_judge,
progress_bar=self.progress_bar,
)
self.rephrase_storage.index_done_callback()
_update_relations.index_done_callback()
logger.info("Shutting down trainee LLM client.")
self.trainee_llm_client.shutdown()
self.trainee_llm_client = None
logger.info("Restarting synthesizer LLM client.")
self.synthesizer_llm_client.restart()

    @async_to_sync_method
async def partition(self, partition_config: Dict):
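        """
        Partition the knowledge graph into batches for QA generation.
        """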
batches = await partition_kg(
self.graph_storage,
self.chunks_storage,
self.tokenizer_instance,
partition_config,
)
self.partition_storage.upsert(batches)
return batches

    @async_to_sync_method
async def extract(self, extract_config: Dict):
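        """
        Extract information from stored chunks with the synthesizer LLM.
        """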
logger.info("Extracting information from given chunks...")
results = await extract_info(
self.synthesizer_llm_client,
self.chunks_storage,
extract_config,
progress_bar=self.progress_bar,
)
if not results:
logger.warning("No information extracted")
return
self.extract_storage.upsert(results)
self.extract_storage.index_done_callback()

    @async_to_sync_method
async def generate(self, generate_config: Dict):
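        """
        Generate QA pairs from the partitioned batches and store them.
        """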
        # Step 1: load partitioned batches
        batches = self.partition_storage.data
if not batches:
logger.warning("No partitions found for QA generation")
return
# Step 2: generate QA pairs
results = await generate_qas(
self.synthesizer_llm_client,
batches,
generate_config,
progress_bar=self.progress_bar,
)
if not results:
logger.warning("No QA pairs generated")
return
# Step 3: store the generated QA pairs
self.qa_storage.upsert(results)
self.qa_storage.index_done_callback()

    @async_to_sync_method
async def clear(self):
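        """
        Drop all cached storages under the working directory.
        """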
        self.full_docs_storage.drop()
        self.chunks_storage.drop()
        self.search_storage.drop()
        self.graph_storage.clear()
        self.rephrase_storage.drop()
        self.partition_storage.drop()
        self.extract_storage.drop()
        self.qa_storage.drop()
        logger.info("All caches are cleared")


# TODO: add data filtering step here in the future
# graph_gen.filter(filter_config=config["filter"])
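
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Apart from "data_sources",
# "quiz_samples" and "re_judge", which this file reads explicitly, the keys
# each config dict expects are defined by the downstream operators; the empty
# dicts below are placeholders showing call order, not a working setup. The
# optional search and extract steps take their own configs in the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    graph_gen = GraphGen()
    graph_gen.read(read_config={})    # input sources for read_files (keys assumed)
    graph_gen.chunk(chunk_config={})  # chunking parameters (keys assumed)
    graph_gen.build_kg()
    graph_gen.quiz_and_judge(
        quiz_and_judge_config={"quiz_samples": 2, "re_judge": False}
    )
    graph_gen.partition(partition_config={})
    graph_gen.generate(generate_config={})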