File size: 1,777 Bytes
799ac7c
 
8e67692
 
d02622b
799ac7c
 
 
 
 
e4316f1
799ac7c
 
 
 
 
d02622b
799ac7c
 
 
 
 
 
8e67692
799ac7c
 
 
 
 
 
2a0edfe
799ac7c
 
f1eedd1
 
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
799ac7c
f1eedd1
e4316f1
799ac7c
f1eedd1
799ac7c
 
 
 
 
 
2a0edfe
799ac7c
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Any

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.models import (
    AggregatedGenerator,
    AtomicGenerator,
    CoTGenerator,
    MultiHopGenerator,
    VQAGenerator,
)
from graphgen.utils import logger, run_concurrent


async def generate_qas(
    llm_client: BaseLLMWrapper,
    batches: list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ],
    generation_config: dict,
    progress_bar: gr.Progress | None = None,
) -> list[dict[str, Any]]:
    """
    Generate question-answer pairs based on nodes and edges.

    Picks a generator class from ``generation_config["method"]``, runs its
    ``generate`` coroutine over every batch via ``run_concurrent``, and then
    formats the raw results with the generator's
    ``format_generation_results``.

    :param llm_client: LLM client handed to the selected generator.
    :param batches: Batches to generate from. Per the annotation, each batch
        is a ``(nodes, edges)`` pair: ``nodes`` is a list of ``(str, dict)``
        tuples and ``edges`` a list of 3-tuples. Each batch is passed as one
        item to ``generator.generate``.
    :param generation_config: Must contain ``"method"`` (one of ``"atomic"``,
        ``"aggregated"``, ``"multi_hop"``, ``"cot"``, ``"vqa"``) and
        ``"data_format"`` (forwarded as ``output_data_format``).
    :param progress_bar: Optional Gradio progress bar forwarded to
        ``run_concurrent``.
    :return: QA pairs as returned by ``format_generation_results``.
    :raises KeyError: if ``"method"`` or ``"data_format"`` is missing from
        ``generation_config``; raised before any LLM calls are made.
    :raises ValueError: if ``"method"`` is not a supported generation mode.
    """
    method = generation_config["method"]
    # Read the output format up front so a missing key fails fast instead of
    # after all (potentially expensive) LLM generation has already run.
    data_format = generation_config["data_format"]
    logger.info("[Generation] mode: %s, batches: %d", method, len(batches))

    generator_classes = {
        "atomic": AtomicGenerator,
        "aggregated": AggregatedGenerator,
        "multi_hop": MultiHopGenerator,
        "cot": CoTGenerator,
        "vqa": VQAGenerator,
    }
    try:
        generator_cls = generator_classes[method]
    except KeyError:
        raise ValueError(f"Unsupported generation mode: {method}") from None
    generator = generator_cls(llm_client)

    results = await run_concurrent(
        generator.generate,
        batches,
        desc="[4/4]Generating QAs",
        unit="batch",
        progress_bar=progress_bar,
    )

    # format
    logger.info("Output data format: %s", data_format)
    return generator.format_generation_results(
        results, output_data_format=data_format
    )