Spaces:
Running
Running
File size: 1,777 Bytes
799ac7c 8e67692 d02622b 799ac7c e4316f1 799ac7c d02622b 799ac7c 8e67692 799ac7c 2a0edfe 799ac7c f1eedd1 799ac7c f1eedd1 799ac7c f1eedd1 799ac7c f1eedd1 799ac7c f1eedd1 799ac7c f1eedd1 e4316f1 799ac7c f1eedd1 799ac7c 2a0edfe 799ac7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from typing import Any
import gradio as gr
from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
AggregatedGenerator,
AtomicGenerator,
CoTGenerator,
MultiHopGenerator,
VQAGenerator,
)
from graphgen.utils import logger, run_concurrent
async def generate_qas(
    llm_client: BaseLLMWrapper,
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    generation_config: dict,
    progress_bar: gr.Progress | None = None,
) -> list[dict[str, Any]]:
    """
    Generate question-answer pairs based on nodes and edges.

    :param llm_client: LLM client handed to the selected generator
    :param batches: batches of (nodes, edges) groups to generate QAs from
    :param generation_config: must contain "method" (one of "atomic",
        "aggregated", "multi_hop", "cot", "vqa") and "data_format"
    :param progress_bar: optional gradio progress bar updated by run_concurrent
    :return: QA pairs formatted according to generation_config["data_format"]
    :raises ValueError: if generation_config["method"] is not supported
    """
    method = generation_config["method"]
    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))

    # Dispatch table keeps the method -> generator mapping in one place and
    # removes the inconsistent `elif method in ["vqa"]` membership test the
    # if/elif chain had (all other branches compared with ==).
    generator_classes = {
        "atomic": AtomicGenerator,
        "aggregated": AggregatedGenerator,
        "multi_hop": MultiHopGenerator,
        "cot": CoTGenerator,
        "vqa": VQAGenerator,
    }
    try:
        generator = generator_classes[method](llm_client)
    except KeyError:
        # `from None` suppresses the KeyError context so callers see only
        # the same ValueError the original code raised.
        raise ValueError(f"Unsupported generation mode: {method}") from None

    results = await run_concurrent(
        generator.generate,
        batches,
        desc="[4/4]Generating QAs",
        unit="batch",
        progress_bar=progress_bar,
    )

    # Convert the raw generation results into the requested output format.
    data_format = generation_config["data_format"]
    logger.info("Output data format: %s", data_format)
    results = generator.format_generation_results(
        results, output_data_format=data_format
    )
    return results
|