File size: 4,079 Bytes
acd7cf4
 
 
 
3a3b216
d02622b
 
acd7cf4
3a3b216
acd7cf4
 
 
d02622b
3a3b216
 
 
 
 
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
 
 
 
 
3a3b216
acd7cf4
3a3b216
acd7cf4
3a3b216
acd7cf4
 
 
 
 
 
 
 
 
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
3a3b216
 
 
 
 
 
 
acd7cf4
3a3b216
 
 
 
 
 
 
 
 
acd7cf4
 
 
 
 
 
3a3b216
acd7cf4
 
 
 
3a3b216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd7cf4
3a3b216
acd7cf4
 
3a3b216
acd7cf4
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
from collections import defaultdict

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage
from graphgen.templates import DESCRIPTION_REPHRASING_PROMPT
from graphgen.utils import detect_main_language, logger


async def quiz(
    synth_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    max_samples: int = 1,
    max_concurrent: int = 1000,
) -> JsonKVStorage:
    """
    Build quiz statements for every edge and node description in the graph.

    Each description is first recorded as ``(description, "yes")``. Then, for
    each of ``max_samples`` rounds, one ANTI_TEMPLATE prompt is sent to the LLM
    and its answer is labelled "no"; from the second round on, a TEMPLATE
    prompt is sent as well and its answer is labelled "yes". Descriptions that
    already have an entry in ``rephrase_storage`` trigger no LLM call.

    The collected ``(statement, ground_truth)`` pairs are de-duplicated per
    description (keeping first-seen order, so the original description stays
    first) and upserted into ``rephrase_storage`` keyed by the description.

    :param synth_llm_client: LLM client used to generate statements
    :param graph_storage: graph storage instance providing edges and nodes
    :param rephrase_storage: storage the generated pairs are upserted into
    :param max_samples: number of sampling rounds per description
    :param max_concurrent: max number of quiz coroutines running concurrently
    :return: the ``rephrase_storage`` instance that was passed in
    """

    semaphore = asyncio.Semaphore(max_concurrent)

    async def _process_single_quiz(des: str, prompt: str, gt: str):
        """Return ``{des: [(answer, gt)]}`` or None if cached or on failure."""
        async with semaphore:
            try:
                # Skip generation if the description is already present in
                # rephrase_storage.
                descriptions = await rephrase_storage.get_by_id(des)
                if descriptions:
                    return None

                new_description = await synth_llm_client.generate_answer(
                    prompt, temperature=1
                )
                return {des: [(new_description, gt)]}

            except Exception as e:  # pylint: disable=broad-except
                # Best effort: one failed sample must not abort the whole run.
                logger.error("Error when quizzing description %s: %s", des, e)
                return None

    edges = await graph_storage.get_all_edges()
    nodes = await graph_storage.get_all_nodes()

    results = defaultdict(list)
    tasks = []

    def _schedule_quizzes(description: str) -> None:
        """Seed ``results`` with the description and queue its LLM tasks."""
        language = "English" if detect_main_language(description) == "en" else "Chinese"
        templates = DESCRIPTION_REPHRASING_PROMPT[language]

        # The original description is always a ground-truth "yes" sample.
        results[description] = [(description, "yes")]

        for i in range(max_samples):
            # The first round only produces a "no" sample, because the
            # original description already serves as the "yes" sample.
            if i > 0:
                tasks.append(
                    _process_single_quiz(
                        description,
                        templates["TEMPLATE"].format(input_sentence=description),
                        "yes",
                    )
                )
            tasks.append(
                _process_single_quiz(
                    description,
                    templates["ANTI_TEMPLATE"].format(input_sentence=description),
                    "no",
                )
            )

    # Edges are (source, target, data) and nodes are (id, data), judging by
    # the indices used here; both carry their text under "description".
    for edge in edges:
        _schedule_quizzes(edge[2]["description"])

    for node in nodes:
        _schedule_quizzes(node[1]["description"])

    for result in tqdm_async(
        asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions"
    ):
        new_result = await result
        if new_result:
            for key, value in new_result.items():
                results[key].extend(value)

    for key, value in results.items():
        # dict.fromkeys removes duplicates while preserving insertion order,
        # which keeps the stored sample order deterministic across runs.
        results[key] = list(dict.fromkeys(value))
        await rephrase_storage.upsert({key: results[key]})

    return rephrase_storage