File size: 2,947 Bytes
8e67692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e67c3b
8e67692
 
 
 
 
 
 
 
 
 
 
 
 
 
9e67c3b
 
8e67692
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e67c3b
8e67692
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from collections import defaultdict

import gradio as gr

from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator
from graphgen.utils import logger, run_concurrent


def _collect_descriptions(graph_storage: NetworkXStorage) -> list[str]:
    """Return the distinct edge and node descriptions (edges first) in first-seen order."""
    descriptions = [edge[2]["description"] for edge in graph_storage.get_all_edges()]
    descriptions += [node[1]["description"] for node in graph_storage.get_all_nodes()]
    # dict.fromkeys keeps insertion order, so this dedups deterministically.
    return list(dict.fromkeys(descriptions))


async def quiz(
    synth_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    max_samples: int = 1,
    progress_bar: gr.Progress = None,
) -> JsonKVStorage:
    """
    Get all edges and nodes and quiz their descriptions using QuizGenerator.

    For every distinct description found on the graph's edges and nodes that
    is not already present in ``rephrase_storage``, a list of
    ``(text, ground_truth)`` pairs is built:

    * the original description itself, labelled ``"yes"``;
    * ``max_samples - 1`` LLM rephrasings from the ``"TEMPLATE"`` prompt,
      labelled ``"yes"``;
    * ``max_samples`` LLM rephrasings from the ``"ANTI_TEMPLATE"`` prompt,
      labelled ``"no"``.

    Failed LLM calls are logged and dropped. Duplicate pairs are removed
    (first-seen order kept), and each list is upserted into
    ``rephrase_storage`` keyed by the original description.

    :param synth_llm_client: LLM client used to generate the rephrased statements
    :param graph_storage: graph storage whose edge and node descriptions are quizzed
    :param rephrase_storage: rephrase storage instance; descriptions already
        stored there are skipped and their entries are left untouched
    :param max_samples: max samples per label for each description
    :param progress_bar: optional Gradio progress bar forwarded to run_concurrent
    :return: ``rephrase_storage``, updated in place
    """

    generator = QuizGenerator(synth_llm_client)

    async def _process_single_quiz(item: tuple[str, str, str]):
        """Rephrase one description with one template; None on failure."""
        description, template_type, gt = item
        try:
            prompt = generator.build_prompt_for_description(description, template_type)
            new_description = await synth_llm_client.generate_answer(
                prompt, temperature=1
            )
            rephrased_text = generator.parse_rephrased_text(new_description)
            return {description: [(rephrased_text, gt)]}

        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error when quizzing description %s: %s", description, e)
            return None

    # Skip descriptions quizzed in an earlier run up front. Enqueueing them
    # only wasted work, and re-upserting them could replace their stored
    # rephrasings with just the original description.
    pending = [
        description
        for description in _collect_descriptions(graph_storage)
        if not rephrase_storage.get_by_id(description)
    ]

    results = defaultdict(list)
    items = []
    for description in pending:
        results[description].append((description, "yes"))
        for i in range(max_samples):
            # The original description already counts as the first "yes"
            # sample, so only max_samples - 1 positive rephrasings are needed.
            if i > 0:
                items.append((description, "TEMPLATE", "yes"))
            items.append((description, "ANTI_TEMPLATE", "no"))

    quiz_results = await run_concurrent(
        _process_single_quiz,
        items,
        desc="Quizzing descriptions",
        unit="description",
        progress_bar=progress_bar,
    )

    for new_result in quiz_results:
        if not new_result:
            continue
        for key, value in new_result.items():
            results[key].extend(value)

    if results:
        # Order-preserving dedup: set() would order str tuples differently
        # between runs because of hash randomization.
        rephrase_storage.upsert(
            {key: list(dict.fromkeys(value)) for key, value in results.items()}
        )

    return rephrase_storage