Spaces:
Sleeping
Sleeping
| import re | |
| import spacy | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| import random | |
| from datetime import datetime, timedelta | |
| from dateutil.parser import parse as parse_date | |
| # A simplistic Ungrounded Answer Generator. | |
class UngroundedAnswerGenerator:
    """Produce a deliberately altered ("ungrounded") variant of an answer.

    The perturbation is chosen from the named entities spaCy finds in the
    answer:

    * a ``DATE`` entity       -> the date is shifted by a random number of days;
    * ``MONEY`` / ``PERCENT`` -> the first number in the answer is rescaled;
    * anything else           -> a phrase from ``financial_terms`` that is
      similar to the answer is returned instead.
    """

    # Candidate distractor phrases (credit-card / banking vocabulary).
    _DEFAULT_FINANCIAL_TERMS = (
        "CommBank Credit Card",
        "Personal credit cards",
        "Business credit cards",
        "PIN",
        "ePayments Code",
        "Conditions of Use",
        "Schedule of Credit Card Particulars",
        "Banking Code of Practice",
        "NetBank",
        "CommBank app",
        "Electronic Banking Terms and Conditions",
        "Tap & Pay",
        "cash advance",
        "credit limit",
        "ATM cash withdrawals",
        "international transaction fee",
        "Mastercard",
        "Visa",
        "balance transfers",
        "regular payments",
        "additional cardholder",
        "digital wallet",
        "statements and notices",
        "closing balance",
        "minimum payment",
        "interest-free period on purchases",
        "SurePay instalment plan",
        "AutoPay",
        "fees and interest rates",
        "annual interest rates",
        "daily interest rate",
        "statement period",
        "balance transfer period",
        "unauthorised transaction",
        "card scheme refunds",
        "purchase plan",
        "card balance plan",
        "cash advance balance plan",
        "instalment setup fee",
        "purchase balance",
        "cash advances balance",
        "interest rate for the plan",
        "credit card account",
        "default under your contract",
    )

    # Matches "1,234.56" (comma-grouped thousands) or a plain "1234.56".
    _NUMBER_RE = re.compile(r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?")

    def __init__(self):
        """Load the spaCy pipeline and the sentence-embedding model."""
        self.nlp = spacy.load("en_core_web_sm")
        self.sim_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.financial_terms = list(self._DEFAULT_FINANCIAL_TERMS)
        # (terms tuple, embeddings) filled lazily by _term_embeddings().
        self._term_cache = None

    def generate(self, context: str, answer: str) -> str:
        """Return a perturbed version of ``answer``.

        Args:
            context: Source passage; forwarded to the strategy (the current
                strategies do not read it).
            answer: The answer text to perturb.

        Returns:
            The perturbed answer, or ``answer`` unchanged when the selected
            strategy cannot apply its perturbation.
        """
        strategy = self._select_strategy(answer)
        return strategy(context, answer)

    def _select_strategy(self, answer: str):
        """Pick the perturbation method from the entity labels in ``answer``.

        Returns a bound method taking ``(context, answer)`` and returning str.
        """
        labels = {ent.label_ for ent in self.nlp(answer).ents}
        if "DATE" in labels:
            return self._perturb_dates
        if labels & {"MONEY", "PERCENT"}:
            return self._perturb_numbers
        return self._semantic_distractor

    def _perturb_numbers(self, context: str, answer: str) -> str:
        """Rescale the first number in ``answer`` by a random factor.

        Answers containing ``$`` are scaled by 0.8-1.2 and rendered as
        ``$x.xx``; answers containing ``%`` are scaled by 0.5-1.5 and rendered
        as ``x.x%``. Any other answer is returned unchanged.
        """
        if "$" in answer:
            base = self._extract_number(answer)
            return f"${base * random.uniform(0.8, 1.2):.2f}"
        if "%" in answer:
            base = self._extract_number(answer)
            return f"{base * random.uniform(0.5, 1.5):.1f}%"
        return answer

    def _perturb_dates(self, context: str, answer: str) -> str:
        """Shift the date in ``answer`` by 1-30 days in either direction.

        Returns the shifted date as ``YYYY-MM-DD``. If ``answer`` cannot be
        parsed as a date (or the shifted date is out of range) it is returned
        unchanged.
        """
        # Never draw a zero offset: that would reproduce the original date.
        offset_days = random.choice((-1, 1)) * random.randint(1, 30)
        try:
            parsed = parse_date(answer)
            return (parsed + timedelta(days=offset_days)).strftime("%Y-%m-%d")
        except (ValueError, OverflowError, TypeError):
            # Best effort: unparseable or out-of-range dates are left as-is.
            return answer

    def _semantic_distractor(self, context: str, answer: str) -> str:
        """Return the financial term ranked second by similarity to ``answer``.

        Similarity is the dot product between the embedding of ``answer`` and
        the embedding of each entry of ``financial_terms``.
        NOTE(review): taking the runner-up rather than the best match looks
        intended to avoid echoing the answer itself -- confirm.
        """
        answer_emb = self.sim_model.encode(answer)
        similarities = np.dot(self._term_embeddings(), answer_emb)
        runner_up = np.argsort(similarities)[-2]
        return self.financial_terms[runner_up]

    def _term_embeddings(self):
        """Return embeddings for ``financial_terms``, encoding them only once.

        The cache is keyed on the current terms, so edits made to
        ``financial_terms`` after construction trigger a re-encode.
        """
        terms = tuple(self.financial_terms)
        cached = getattr(self, "_term_cache", None)
        if cached is None or cached[0] != terms:
            cached = (terms, self.sim_model.encode(list(terms)))
            self._term_cache = cached
        return cached[1]

    def _extract_number(self, text: str) -> float:
        """Return the first number found in ``text``.

        Comma thousands separators are understood ("1,234.56" -> 1234.56).
        When ``text`` holds no number, a random value in [1, 1000] is
        returned so callers still have something to perturb.
        """
        match = self._NUMBER_RE.search(text)
        if match is None:
            return random.uniform(1, 1000)
        return float(match.group().replace(",", ""))