Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
class GroundednessChecker:
    """Score whether an answer is supported by a reference context.

    Wraps a sequence-classification model (and its tokenizer) loaded from
    ``model_path``. Label index ``GROUNDED_LABEL`` (1) is interpreted as
    "grounded"; any other predicted index is reported as not grounded.
    """

    # Class index treated as "grounded". The original code relied on index 1
    # meaning grounded (both for the decision and the reported confidence).
    # NOTE(review): confirm against the detector's config.id2label.
    GROUNDED_LABEL = 1
    # Token budget for the question/answer/context pair; longer inputs are
    # truncated by the tokenizer.
    MAX_LENGTH = 512
    # Number of context characters echoed back in the result's "details".
    SNIPPET_CHARS = 200

    def __init__(self, model_path: str = "./grounding_detector") -> None:
        """Load the tokenizer and classifier from *model_path*.

        The model is moved to CUDA when available, otherwise CPU, and put in
        evaluation mode since it is only used for inference.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Explicit inference mode (disables dropout); harmless if already set.
        self.model.eval()

    def check(self, question: str, answer: str, context: str) -> dict:
        """Check whether *answer* to *question* is grounded in *context*.

        Args:
            question: The question that was asked.
            answer: The candidate answer to verify.
            context: Reference text the answer should be supported by.

        Returns:
            A dict with:
              - ``"is_grounded"``: True when the predicted class is
                ``GROUNDED_LABEL``.
              - ``"confidence"``: softmax probability of ``GROUNDED_LABEL``
                (always the grounded class, not the predicted class).
              - ``"details"``: the question, the answer, and the first
                ``SNIPPET_CHARS`` characters of the context (with ``"..."``
                appended when the context was cut).
        """
        inputs = self.tokenizer(
            question,
            # NOTE(review): a literal " [SEP] " joins answer and context. This
            # presumably matches the format the detector was fine-tuned on;
            # for tokenizers whose separator is not "[SEP]" it is plain text.
            answer + " [SEP] " + context,
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # One example per call: take row 0, then work over the class axis so
        # the decision and the confidence refer to the same label index.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        predicted = int(probs.argmax(dim=-1))

        if len(context) > self.SNIPPET_CHARS:
            snippet = context[: self.SNIPPET_CHARS] + "..."
        else:
            snippet = context

        return {
            "is_grounded": predicted == self.GROUNDED_LABEL,
            "confidence": probs[self.GROUNDED_LABEL].item(),
            "details": {
                "question": question,
                "answer": answer,
                "context_snippet": snippet,
            },
        }
# Usage example: score one answer that the context supports and one it doesn't.
if __name__ == "__main__":
    checker = GroundednessChecker()

    # Sample banking PDS (account terms) text used as the reference context.
    context = """
Premium Savings Account Terms:
- Annual Percentage Yield (APY): 4.25%
- Minimum opening deposit: $1,000
- Monthly maintenance fee: $5 (waived if daily balance >= $1,000)
- Maximum withdrawals: 6 per month
"""

    # (label, question, answer). The first answer matches the context
    # ($1,000); the second contradicts it ($10 vs the stated $5 fee).
    examples = [
        ("Grounded", "What is the minimum opening deposit?", "$1,000"),
        ("Ungrounded", "What is the monthly maintenance fee?", "$10 monthly charge"),
    ]
    for label, q, a in examples:
        outcome = checker.check(question=q, answer=a, context=context)
        print(f"{label} Result:", outcome)