from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import COT_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


class CoTGenerator(BaseGenerator):
    """Generate Chain-of-Thought (CoT) QA pairs from batches of graph nodes and edges."""

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the prompt for CoT template design.

        :param batch: a (nodes, edges) tuple, where each node is (name, attrs)
            and each edge is (source, target, attrs); every attrs dict must
            carry a "description" key.
        :return: the formatted prompt string.
        """
        nodes, edges = batch
        # Number the entities and relationships so the prompt can reference them.
        entities_str = "\n".join(
            [
                f"{index + 1}. {node[0]}: {node[1]['description']}"
                for index, node in enumerate(nodes)
            ]
        )
        relationships_str = "\n".join(
            [
                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
                for index, edge in enumerate(edges)
            ]
        )
        # Match the prompt language to the language of the batch content.
        language = detect_main_language(entities_str + relationships_str)
        prompt = COT_GENERATION_PROMPT[language]["COT_TEMPLATE_DESIGN"].format(
            entities=entities_str, relationships=relationships_str
        )
        return prompt
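
    # For illustration (hypothetical data; real node/edge attrs come from the
    # knowledge graph): a batch like
    #     ([("Turing", {"description": "British mathematician"})],
    #      [("Turing", "Enigma", {"description": "broke the Enigma cipher"})])
    # yields entities_str == "1. Turing: British mathematician" and
    # relationships_str == "1. Turing -- Enigma: broke the Enigma cipher".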

    @staticmethod
    def build_prompt_for_cot_generation(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]],
        question: str,
        reasoning_path: str,
    ) -> str:
        """
        Build the prompt for CoT answer generation.

        :param batch: a (nodes, edges) tuple, formatted as in build_prompt.
        :param question: the question produced by the template-design step.
        :param reasoning_path: the reasoning path guiding the answer.
        :return: the formatted prompt string.
        """
        nodes, edges = batch
        entities_str = "\n".join(
            [
                f"{index + 1}. {node[0]}: {node[1]['description']}"
                for index, node in enumerate(nodes)
            ]
        )
        relationships_str = "\n".join(
            [
                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
                for index, edge in enumerate(edges)
            ]
        )
        language = detect_main_language(entities_str + relationships_str)
        prompt = COT_GENERATION_PROMPT[language]["COT_GENERATION"].format(
            entities=entities_str,
            relationships=relationships_str,
            question=question,
            reasoning_template=reasoning_path,
        )
        return prompt

    @staticmethod
    def parse_response(response: str) -> dict:
        """
        Parse the template-design response into a question and a reasoning path.

        Expects "Question:" / "Reasoning-Path Design:" markers in English
        responses, or "问题:" / "推理路径设计:" in Chinese ones; returns an
        empty dict if neither pair of markers is found.
        """
        if "Question:" in response and "Reasoning-Path Design:" in response:
            question = (
                response.split("Question:")[1]
                .split("Reasoning-Path Design:")[0]
                .strip()
            )
            reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
        elif "问题:" in response and "推理路径设计:" in response:
            question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
            reasoning_path = response.split("推理路径设计:")[1].strip()
        else:
            logger.warning("Failed to parse CoT template: %s", response)
            return {}
        # Drop any quotation marks the model wrapped around the fields.
        question = question.strip('"')
        reasoning_path = reasoning_path.strip('"')
        logger.debug("CoT Question: %s", question)
        logger.debug("CoT Reasoning Path: %s", reasoning_path)
        return {
            "question": question,
            "reasoning_path": reasoning_path,
        }
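
    # For illustration, parse_response expects responses shaped like
    # (hypothetical content):
    #
    #     Question: How did Turing's work at Bletchley Park affect the war?
    #     Reasoning-Path Design: 1. Identify Turing's role. 2. Trace ...
    #
    # or the Chinese equivalent using the 问题:/推理路径设计: markers.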

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Generate QA pairs for a given batch.

        :param batch: a (nodes, edges) tuple as described in build_prompt.
        :return: QA pairs keyed by the content hash of each question.
        """
        # Step 1: design a CoT template (question + reasoning path) for the batch.
        prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(prompt)
        parsed = self.parse_response(response)
        if not parsed:
            # The template response was malformed; skip this batch.
            return {}
        question, reasoning_path = parsed["question"], parsed["reasoning_path"]
        # Step 2: generate the CoT answer that follows the reasoning path.
        prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
        cot_answer = await self.llm_client.generate_answer(prompt)
        logger.debug("CoT Answer: %s", cot_answer)
        return {
            compute_content_hash(question): {
                "question": question,
                "answer": cot_answer,
                "reasoning_path": reasoning_path,
            }
        }
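

# Example usage (a minimal sketch; `my_llm_client` is a stand-in for whatever
# async client BaseGenerator is constructed with, and the `llm_client=` keyword
# is assumed -- the client only needs an awaitable generate_answer(prompt) method):
#
#     import asyncio
#
#     generator = CoTGenerator(llm_client=my_llm_client)
#     batch = (
#         [("Turing", {"description": "British mathematician"})],
#         [("Turing", "Enigma", {"description": "broke the Enigma cipher"})],
#     )
#     qa_pairs = asyncio.run(generator.generate(batch))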