# NOTE(review): the lines above/below this file's code were HuggingFace Spaces
# file-viewer chrome (status badges, file size, commit hashes, line-number
# gutter) accidentally captured during extraction — not part of the source.
from typing import Any
from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
AggregatedGenerator,
AtomicGenerator,
CoTGenerator,
MultiHopGenerator,
VQAGenerator,
)
from graphgen.utils import logger, run_concurrent
async def generate_qas(
    llm_client: BaseLLMWrapper,
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    generation_config: dict,
    progress_bar=None,
) -> list[dict[str, Any]]:
    """
    Generate question-answer pairs based on nodes and edges.

    :param llm_client: LLM wrapper handed to the chosen generator.
    :param batches: pre-partitioned (nodes, edges) batches; each batch is a
        tuple of node tuples and edge tuples as produced upstream.
    :param generation_config: must contain "method" (one of "atomic",
        "aggregated", "multi_hop", "cot", "vqa") and "data_format"
        (passed through to the generator's formatter).
    :param progress_bar: optional progress bar forwarded to run_concurrent.
    :return: QA pairs, formatted according to generation_config["data_format"].
    :raises ValueError: if generation_config["method"] is not supported.
    """
    method = generation_config["method"]
    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))

    # Dispatch table keeps all supported modes in one place; add new modes here.
    generator_classes = {
        "atomic": AtomicGenerator,
        "aggregated": AggregatedGenerator,
        "multi_hop": MultiHopGenerator,
        "cot": CoTGenerator,
        "vqa": VQAGenerator,
    }
    try:
        generator = generator_classes[method](llm_client)
    except KeyError:
        # `from None` hides the internal KeyError from the caller's traceback.
        raise ValueError(f"Unsupported generation mode: {method}") from None

    results = await run_concurrent(
        generator.generate,
        batches,
        desc="[4/4]Generating QAs",
        unit="batch",
        progress_bar=progress_bar,
    )

    # Convert the raw generator output into the requested serialization format.
    data_format = generation_config["data_format"]
    logger.info("Output data format: %s", data_format)
    results = generator.format_generation_results(
        results, output_data_format=data_format
    )
    return results
|