Spaces:
Running
Running
File size: 2,947 Bytes
8e67692 9e67c3b 8e67692 9e67c3b 8e67692 9e67c3b 8e67692 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
from collections import defaultdict
import gradio as gr
from graphgen.bases import BaseLLMWrapper
from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator
from graphgen.utils import logger, run_concurrent
async def quiz(
    synth_llm_client: BaseLLMWrapper,
    graph_storage: NetworkXStorage,
    rephrase_storage: JsonKVStorage,
    max_samples: int = 1,
    progress_bar: gr.Progress = None,
) -> JsonKVStorage:
    """
    Collect every edge and node description and quiz it using QuizGenerator.

    Each description is recorded as its own statement with ground truth "yes".
    In addition, for every description one "ANTI_TEMPLATE" request (ground
    truth "no") is issued per sample, and one "TEMPLATE" request (ground truth
    "yes") is issued for every sample after the first. All generated
    statements are de-duplicated and upserted into ``rephrase_storage``.

    :param synth_llm_client: LLM client used to generate statements
    :param graph_storage: graph storage instance providing edges and nodes
    :param rephrase_storage: rephrase storage instance; descriptions that
        already have an entry in it are not sent to the LLM again
    :param max_samples: max samples for each description
    :param progress_bar: progress object forwarded to ``run_concurrent``
    :return: the ``rephrase_storage`` instance that was passed in, after the
        upserts
    """
    generator = QuizGenerator(synth_llm_client)

    async def _process_single_quiz(item: tuple[str, str, str]):
        """Generate one statement for ``item``; None if skipped or failed."""
        description, template_type, gt = item
        try:
            # Skip descriptions that already have an entry in rephrase_storage
            # so they are not generated a second time.
            if rephrase_storage.get_by_id(description):
                return None

            prompt = generator.build_prompt_for_description(description, template_type)
            new_description = await synth_llm_client.generate_answer(
                prompt, temperature=1
            )
            rephrased_text = generator.parse_rephrased_text(new_description)
            return {description: [(rephrased_text, gt)]}
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: log and drop this item instead of propagating.
            logger.error("Error when quizzing description %s: %s", description, e)
            return None

    edges = graph_storage.get_all_edges()
    nodes = graph_storage.get_all_nodes()

    # Edges expose their attribute mapping at index 2, nodes at index 1.
    # Both kinds are handled identically from here on, so gather the
    # descriptions once (edges first, then nodes).
    descriptions = [edge[2]["description"] for edge in edges]
    descriptions.extend(node[1]["description"] for node in nodes)

    results = defaultdict(list)
    items = []
    for description in descriptions:
        # The original description always counts as a "yes" statement.
        results[description] = [(description, "yes")]
        for i in range(max_samples):
            # The first sample only asks for the "no" variant; the "yes"
            # side is already covered by the original description.
            if i > 0:
                items.append((description, "TEMPLATE", "yes"))
            items.append((description, "ANTI_TEMPLATE", "no"))

    quiz_results = await run_concurrent(
        _process_single_quiz,
        items,
        desc="Quizzing descriptions",
        unit="description",
        progress_bar=progress_bar,
    )

    for new_result in quiz_results:
        if new_result:
            for key, value in new_result.items():
                results[key].extend(value)

    for key, statements in results.items():
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the stored list is deterministic and the original description
        # stays first (a set would yield an arbitrary, run-dependent order).
        rephrase_storage.upsert({key: list(dict.fromkeys(statements))})

    return rephrase_storage
|