from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import VQA_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


class VQAGenerator(BaseGenerator):
    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relationships_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        language = detect_main_language(entities_str + relationships_str)
        prompt = VQA_GENERATION_PROMPT[language].format(
            entities=entities_str, relationships=relationships_str
        )
        return prompt
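    # Illustrative shape of the numbered listings that build_prompt interpolates
    # into the template (hypothetical data, not from the source):
    #
    #   entities_str:
    #     1. Eiffel Tower: A wrought-iron lattice tower in Paris.
    #     2. Paris: The capital of France.
    #
    #   relationships_str:
    #     1. Eiffel Tower -- Paris: The Eiffel Tower stands in Paris.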
    @staticmethod
    def parse_response(response: str) -> Any:
        """
        Parse the LLM response into the generated QA pairs.

        :param response: raw LLM response text
        :return: QA pairs keyed by question content hash
        """
        qa_pairs = {}
        # QA pairs are expected to be separated by blank lines.
        qa_list = response.strip().split("\n\n")
        for qa in qa_list:
            # Support both English and Chinese question/answer markers.
            if "Question:" in qa and "Answer:" in qa:
                question = qa.split("Question:")[1].split("Answer:")[0].strip()
                answer = qa.split("Answer:")[1].strip()
            elif "问题:" in qa and "答案:" in qa:
                question = qa.split("问题:")[1].split("答案:")[0].strip()
                answer = qa.split("答案:")[1].strip()
            else:
                logger.error("Failed to parse QA pair: %s", qa)
                continue
            # Drop surrounding quotes the model sometimes adds.
            question = question.strip('"')
            answer = answer.strip('"')
            logger.debug("Question: %s", question)
            logger.debug("Answer: %s", answer)
            qa_pairs[compute_content_hash(question)] = {
                "question": question,
                "answer": answer,
            }
        return qa_pairs
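    # Sketch of a round trip through parse_response (hypothetical response text;
    # compute_content_hash values abbreviated as "<hash1>"/"<hash2>"):
    #
    #   response = (
    #       "Question: What is pictured?\nAnswer: A tower.\n\n"
    #       "问题: 图中是什么?\n答案: 一座塔。"
    #   )
    #   parse_response(response)
    #   -> {"<hash1>": {"question": "What is pictured?", "answer": "A tower."},
    #       "<hash2>": {"question": "图中是什么?", "answer": "一座塔。"}}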
    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Generate QA pairs for a given batch of nodes and edges.

        :param batch: (nodes, edges) tuple to generate QAs from
        :return: QA pairs keyed by question content hash
        """
        result = {}
        prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(prompt)
        qa_pairs = self.parse_response(response)  # may contain one or more QA pairs
        # Attach the image path of the (last) image-bearing node to every QA pair.
        nodes, _ = batch
        for node in nodes:
            node_data = node[1]
            if "images" in node_data and node_data["images"]:
                img_path = node_data["images"]["img_path"]
                for qa in qa_pairs.values():
                    qa["img_path"] = img_path
        result.update(qa_pairs)
        return result
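    # Minimal usage sketch (assumes BaseGenerator wires up an llm_client exposing
    # the async generate_answer method used above; constructor signature and all
    # data below are hypothetical):
    #
    #   generator = VQAGenerator(llm_client=my_client)
    #   batch = (
    #       [("Eiffel Tower", {"description": "...", "images": {"img_path": "eiffel.jpg"}})],
    #       [("Eiffel Tower", "Paris", {"description": "..."})],
    #   )
    #   qa_pairs = await generator.generate(batch)
    #   # -> {"<hash>": {"question": ..., "answer": ..., "img_path": "eiffel.jpg"}}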
    @staticmethod
    def format_generation_results(
        results: list[dict], output_data_format: str
    ) -> list[dict[str, Any]]:
        if output_data_format == "Alpaca":
            results = [
                {
                    "instruction": v["question"],
                    "input": "",
                    "output": v["answer"],
                    "image": v.get("img_path", ""),
                }
                for item in results
                for v in item.values()
            ]
        elif output_data_format == "Sharegpt":
            results = [
                {
                    "conversations": [
                        {
                            "from": "human",
                            "value": [
                                {"text": v["question"], "image": v.get("img_path", "")}
                            ],
                        },
                        {"from": "gpt", "value": v["answer"]},
                    ]
                }
                for item in results
                for v in item.values()
            ]
        elif output_data_format == "ChatML":
            results = [
                {
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {"text": v["question"], "image": v.get("img_path", "")}
                            ],
                        },
                        {"role": "assistant", "content": v["answer"]},
                    ]
                }
                for item in results
                for v in item.values()
            ]
        else:
            raise ValueError(f"Unknown output data format: {output_data_format}")
        return results
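    # Example of the Alpaca branch on one parsed record (illustrative values):
    #
    #   format_generation_results(
    #       [{"<hash>": {"question": "What is pictured?", "answer": "A tower.",
    #                    "img_path": "eiffel.jpg"}}],
    #       output_data_format="Alpaca",
    #   )
    #   -> [{"instruction": "What is pictured?", "input": "",
    #        "output": "A tower.", "image": "eiffel.jpg"}]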