Spaces:
Running
Running
File size: 2,574 Bytes
799ac7c d02622b 799ac7c d02622b 8c66169 799ac7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from abc import ABC, abstractmethod
from typing import Any
from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
class BaseGenerator(ABC):
    """
    Generate QAs based on given prompts.

    Subclasses implement ``build_prompt`` and ``parse_response``; ``generate``
    chains them around a single call to the LLM client, and
    ``format_generation_results`` reshapes collected QA pairs into one of the
    supported dataset layouts.
    """

    def __init__(self, llm_client: BaseLLMWrapper):
        """
        :param llm_client: client whose ``generate_answer`` coroutine is
            awaited by ``generate``
        """
        self.llm_client = llm_client

    @staticmethod
    @abstractmethod
    def build_prompt(
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ) -> str:
        """Build prompt for LLM based on the given batch"""

    @staticmethod
    @abstractmethod
    def parse_response(response: str) -> Any:
        """Parse the LLM response and return the generated QAs"""

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Generate QAs based on a given batch.

        :param batch: two-element tuple forwarded unchanged to
            ``build_prompt`` (presumably graph nodes and edges -- confirm
            against callers)
        :return: QA pairs returned by ``parse_response``, copied into a new
            dict; an empty dict when the parser produced nothing
        """
        result: dict[str, Any] = {}
        prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(prompt)
        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
        # A parser that could not extract anything may hand back None or an
        # empty mapping; dict.update(None) raises TypeError, so skip it.
        if qa_pairs:
            result.update(qa_pairs)
        return result

    @staticmethod
    def _to_alpaca(qa: dict) -> dict[str, Any]:
        """Render one QA pair as an Alpaca-style instruction record."""
        return {
            "instruction": qa["question"],
            "input": "",
            "output": qa["answer"],
        }

    @staticmethod
    def _to_sharegpt(qa: dict) -> dict[str, Any]:
        """Render one QA pair as a Sharegpt-style two-turn conversation."""
        return {
            "conversations": [
                {"from": "human", "value": qa["question"]},
                {"from": "gpt", "value": qa["answer"]},
            ]
        }

    @staticmethod
    def _to_chatml(qa: dict) -> dict[str, Any]:
        """Render one QA pair as a ChatML-style user/assistant message list."""
        return {
            "messages": [
                {"role": "user", "content": qa["question"]},
                {"role": "assistant", "content": qa["answer"]},
            ]
        }

    @staticmethod
    def format_generation_results(
        results: list[dict], output_data_format: str
    ) -> list[dict[str, Any]]:
        """
        Flatten generated QA pairs into records of the requested layout.

        :param results: list of dicts whose values are mappings holding a
            ``"question"`` and an ``"answer"`` entry; the dict keys are ignored
        :param output_data_format: one of ``"Alpaca"``, ``"Sharegpt"`` or
            ``"ChatML"`` (matched exactly, case-sensitive)
        :return: one record per QA pair, in the iteration order of ``results``
        :raises ValueError: if ``output_data_format`` is not a known format
        :raises KeyError: if a QA mapping lacks ``"question"`` or ``"answer"``
        """
        # Select the per-pair converter first so an unknown format is rejected
        # before any result is touched, even when ``results`` is empty.
        if output_data_format == "Alpaca":
            convert = BaseGenerator._to_alpaca
        elif output_data_format == "Sharegpt":
            convert = BaseGenerator._to_sharegpt
        elif output_data_format == "ChatML":
            convert = BaseGenerator._to_chatml
        else:
            raise ValueError(f"Unknown output data format: {output_data_format}")

        # Only the QA payloads matter; the keys of each result dict are unused.
        return [convert(qa) for item in results for qa in item.values()]
|