Ramon Meffert committed
Commit · 1f08ed2
1 Parent(s): a1746cf
Add query cli w/ argparse
query.py ADDED
@@ -0,0 +1,84 @@
import argparse
import torch
import transformers

from typing import List
from datasets import load_dataset, DatasetDict
from dotenv import load_dotenv

from src.readers.dpr_reader import DprReader
from src.retrievers.base_retriever import Retriever
from src.retrievers.es_retriever import ESRetriever
from src.retrievers.faiss_retriever import FaissRetriever
from src.utils.preprocessing import result_to_reader_input
from src.utils.log import get_logger


def get_retriever(r: str, ds: DatasetDict) -> Retriever:
    retriever = ESRetriever if r == "es" else FaissRetriever
    return retriever(ds)


def print_name(contexts: dict, section: str, id: int):
    name = contexts[section][id]
    if name != 'nan':
        print(f" {section}: {name}")


def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
    # calculate answer scores
    sm = torch.nn.Softmax(dim=0)
    d_scores = sm(torch.Tensor(
        [pred.relevance_score for pred in answers]))
    s_scores = sm(torch.Tensor(
        [pred.span_score for pred in answers]))

    for pos, answer in enumerate(answers):
        print(f"{pos + 1:>4}. {answer.text}")
        print(f" {'-' * len(answer.text)}")
        print_name(contexts, 'chapter', answer.doc_id)
        print_name(contexts, 'section', answer.doc_id)
        print_name(contexts, 'subsection', answer.doc_id)
        print(f" retrieval score: {scores[answer.doc_id]:6.02f}%")
        print(f" document score: {d_scores[pos] * 100:6.02f}%")
        print(f" span score: {s_scores[pos] * 100:6.02f}%")
        print()


def main(args: argparse.Namespace):
    # Initialize dataset
    dataset = load_dataset("GroNLP/ik-nlp-22_slp")

    # Retrieve
    retriever = get_retriever(args.retriever, dataset)
    scores, contexts = retriever.retrieve(args.query)

    # Read
    reader = DprReader()
    reader_input = result_to_reader_input(contexts)
    answers = reader.read(args.query, reader_input, num_answers=args.top)

    # Print output
    print_answers(answers, scores, contexts)


if __name__ == "__main__":
    # Setup environment
    load_dotenv()
    logger = get_logger()
    transformers.logging.set_verbosity_error()

    # Set up CLI arguments
    parser = argparse.ArgumentParser(
        formatter_class=argparse.MetavarTypeHelpFormatter
    )
    parser.add_argument("query", type=str,
                        help="The question to feed to the QA system")
    parser.add_argument("--top", "-t", type=int, default=1,
                        help="The number of answers to retrieve")
    parser.add_argument("--retriever", "-r", type=str.lower,
                        choices=["faiss", "es"], default="faiss",
                        help="The retrieval method to use")

    args = parser.parse_args()
    main(args)
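For reference, a hypothetical invocation of the new CLI (the question text is only an example; it assumes the project's dependencies and the `.env` file read by `load_dotenv()` are in place, and that an Elasticsearch instance is only needed when `--retriever es` is chosen):

    python query.py "What is a dependency parse?" --top 3 --retriever faiss

The positional `query` argument is the question fed to the QA system; `--top`/`-t` sets how many answers the DPR reader returns, and `--retriever`/`-r` switches between the FAISS (default) and Elasticsearch retrievers.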