Spaces:
Sleeping
Sleeping
| from template import TEMPLATE_v1, TEMPLATE_v2, TEMPLATE_v3, QUESTIONS | |
| import json | |
class Prompt:
    """Build prompts from versioned templates and parse the JSON replies.

    Three versions are supported: ``'v1'``, ``'v2'`` and ``'v3'``. The active
    version lives in ``self.version`` (default ``'v3'``). It is switched
    whenever a different valid version is passed to :meth:`get` or
    :meth:`process_result`.
    """

    def __init__(self) -> None:
        # TEMPLATE_v1..v3 come from the external ``template`` module (not shown).
        # Judging by the ``.format`` calls below, v1/v2 presumably take two
        # positional fields (input, questions) and v3 one (input) - TODO confirm.
        self.template_v1 = TEMPLATE_v1
        self.template_v2 = TEMPLATE_v2
        self.template_v3 = TEMPLATE_v3
        self.version = "v3"

    def combine_questions(self, questions):
        """Render *questions* as newline-separated ``'Question N: ...'`` lines.

        Items containing ``'Input question'`` are dropped. Numbering follows
        each item's position in the ORIGINAL list, so dropped items leave
        gaps, e.g. ``['a', 'Input question: x', 'b']`` gives
        ``'Question 1: a'`` and ``'Question 3: b'``.
        """
        # NOTE(review): confirm the numbering gaps are intended rather than
        # contiguous numbering of the kept questions.
        lines = [
            f'Question {idx + 1}: {q}'
            for idx, q in enumerate(questions)
            if 'Input question' not in q
        ]
        return '\n'.join(lines)

    def _get_v1(self, input, questions):
        """Fill the v1 template with *input* and the numbered *questions*."""
        # Format with the freshly combined questions; ``self.questions`` is
        # never assigned, so reading it raised AttributeError.
        return self.template_v1.format(input, self.combine_questions(questions))

    def _get_v2(self, input, questions):
        """Fill the v2 template with *input* and the numbered *questions*."""
        return self.template_v2.format(input, self.combine_questions(questions))

    def _get_v3(self, input, questions):
        """Fill the v3 template with *input* only.

        *questions* is unused; it is accepted so all builders share a signature.
        """
        return self.template_v3.format(input)

    def get(self, input, questions, version=None):
        """Return the prompt for *input* using *version* or the current version.

        A falsy *version* means "use ``self.version``". A valid *version*
        becomes the new current version.

        Raises:
            ValueError: if the resolved version is not 'v1', 'v2' or 'v3'.
                ``self.version`` is left unchanged in that case.
        """
        version = version if version else self.version
        builders = {'v1': self._get_v1, 'v2': self._get_v2, 'v3': self._get_v3}
        if version not in builders:
            raise ValueError('Version should be one of {v1, v2, v3}')
        self.version = version
        return builders[version](input, questions)

    def _process_v1(self, res):
        """Decode a v1 reply; the parsed JSON is returned unchanged."""
        return json.loads(res)

    def _process_v2(self, res):
        """Decode a v2 reply; the parsed JSON is returned unchanged."""
        return json.loads(res)

    def _process_v3(self, x):
        """Flatten a v3 JSON reply into ``{'Question N': {...}}``.

        Each top-level value is either a single answer (a dict containing
        ``'answer'``) or a group whose values are single answers. Answers are
        numbered consecutively from 1 in document order. Only the ``'answer'``
        and ``'original sentences'`` fields are kept.

        Raises:
            KeyError: if an answer lacks ``'answer'`` or ``'original sentences'``.
        """
        data = json.loads(x)
        entries = []
        for value in data.values():
            if 'answer' in value:
                entries.append(value)
            else:
                # Grouped answers: previously limited to exactly two sub-entries.
                entries.extend(value.values())
        return {
            f'Question {idx}': {
                "answer": entry['answer'],
                "original sentences": entry['original sentences'],
            }
            for idx, entry in enumerate(entries, start=1)
        }

    def process_result(self, result, version=None):
        """Parse the JSON reply *result* for *version* or the current version.

        If *version* is given, valid, and differs from ``self.version``, it
        becomes the current version and a notice is printed. With
        ``version=None`` the current ``self.version`` is used.

        Raises:
            ValueError: if the version is not 'v1', 'v2' or 'v3'.
                ``self.version`` is left unchanged in that case.
            json.JSONDecodeError: if *result* is not valid JSON.
        """
        processors = {
            'v1': self._process_v1,
            'v2': self._process_v2,
            'v3': self._process_v3,
        }
        if version is not None:
            if version not in processors:
                raise ValueError('Version should be one of {v1, v2, v3}')
            if version != self.version:
                self.version = version
                print(f'Version changed to {version}')
        # Dispatch on the (possibly just updated) current version, so the
        # default ``version=None`` works instead of always raising.
        processor = processors.get(self.version)
        if processor is None:
            raise ValueError('Version should be one of {v1, v2, v3}')
        return processor(result)