File size: 1,741 Bytes
799ac7c
 
d02622b
799ac7c
 
 
 
 
e4316f1
799ac7c
 
 
 
 
d02622b
799ac7c
 
 
 
 
 
2a0edfe
799ac7c
 
 
 
 
 
2a0edfe
799ac7c
 
f1eedd1
 
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
e4316f1
799ac7c
f1eedd1
799ac7c
 
 
 
 
 
2a0edfe
799ac7c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import Any

from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
    AggregatedGenerator,
    AtomicGenerator,
    CoTGenerator,
    MultiHopGenerator,
    VQAGenerator,
)
from graphgen.utils import logger, run_concurrent


# Supported ``generation_config["method"]`` values mapped to the generator class
# implementing each one; every class is constructed with just the LLM client.
_GENERATOR_CLASSES = {
    "atomic": AtomicGenerator,
    "aggregated": AggregatedGenerator,
    "multi_hop": MultiHopGenerator,
    "cot": CoTGenerator,
    "vqa": VQAGenerator,
}


async def generate_qas(
    llm_client: BaseLLMWrapper,
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    generation_config: dict,
    progress_bar=None,
) -> list[dict[str, Any]]:
    """
    Generate question-answer pairs from batches of graph nodes and edges.

    :param llm_client: LLM client passed to the selected generator's constructor.
    :param batches: items handed one by one to the generator's ``generate``
        method via ``run_concurrent``. Presumably each batch is a
        ``(nodes, edges)`` pair -- TODO confirm against callers.
    :param generation_config: must contain ``"method"`` (one of ``"atomic"``,
        ``"aggregated"``, ``"multi_hop"``, ``"cot"``, ``"vqa"``) and
        ``"data_format"`` (forwarded to ``format_generation_results`` as
        ``output_data_format``).
    :param progress_bar: optional progress bar forwarded to ``run_concurrent``.
    :return: QA pairs as returned by the generator's
        ``format_generation_results``.
    :raises ValueError: if ``generation_config["method"]`` is not supported.
    :raises KeyError: if ``"method"`` or ``"data_format"`` is missing from
        ``generation_config``.
    """
    method = generation_config["method"]
    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))

    generator_cls = _GENERATOR_CLASSES.get(method)
    if generator_cls is None:
        raise ValueError(f"Unsupported generation mode: {method}")

    # Read the output format before any LLM work so that a malformed config
    # fails immediately instead of after every batch has been generated.
    data_format = generation_config["data_format"]

    generator = generator_cls(llm_client)

    results = await run_concurrent(
        generator.generate,
        batches,
        desc="[4/4]Generating QAs",
        unit="batch",
        progress_bar=progress_bar,
    )

    logger.info("Output data format: %s", data_format)
    return generator.format_generation_results(
        results, output_data_format=data_format
    )