Commit b06298d
Parent: 615dee0

add experiment code

Changed files:
- README.old.md +3 -3
- main.py +95 -59
- src/retrievers/base_retriever.py +11 -2
- src/retrievers/es_retriever.py +8 -4
- src/retrievers/faiss_retriever.py +4 -2
- test.py +20 -0
README.old.md
CHANGED

@@ -6,12 +6,12 @@
 - [ ] Formules enzo eruit filteren
 - [ ] Splitsen op zinnen...?
 - [ ] Meer language models proberen
-- [ ] Elasticsearch
-- [ ] CLI voor vragen beantwoorden
+- [X] Elasticsearch
+- [X] CLI voor vragen beantwoorden

 ### Extra dingen

-- [ ] Huggingface spaces demo
+- [X] Huggingface spaces demo
 - [ ] Question generation voor finetuning
 - [ ] Language model finetunen

main.py
CHANGED

@@ -1,19 +1,20 @@
-import os
 import random
-from typing import cast
-import time
+from typing import Dict, cast

 import torch
 import transformers
 from datasets import DatasetDict, load_dataset
 from dotenv import load_dotenv
+from query import print_answers

 from src.evaluation import evaluate
 from src.readers.dpr_reader import DprReader
+from src.retrievers.base_retriever import Retriever
 from src.retrievers.es_retriever import ESRetriever
 from src.retrievers.faiss_retriever import FaissRetriever
 from src.utils.log import get_logger
 from src.utils.preprocessing import context_to_reader_input
+from src.utils.timing import get_times, timeit

 logger = get_logger()

@@ -26,62 +27,97 @@ if __name__ == '__main__':
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
     questions = cast(DatasetDict, load_dataset(dataset_name, "questions"))

-    [56 lines removed; their original content is not legible in this view]
+    # Only doing a few questions for speed
+    subset_idx = 3
+    questions_test = questions["test"][:subset_idx]
+
+    experiments: Dict[str, Retriever] = {
+        "faiss": FaissRetriever(paragraphs),
+        # "es": ESRetriever(paragraphs),
+    }
+
+    for experiment_name, retriever in experiments.items():
+        reader = DprReader()
+
+        for idx in range(subset_idx):
+            question = questions_test["question"][idx]
+            answer = questions_test["answer"][idx]
+
+            scores, context = retriever.retrieve(question, 5)
+            reader_input = context_to_reader_input(context)
+
+            # workaround so we can use the decorator with a dynamic name for time recording
+            time_wrapper = timeit(f"{experiment_name}.read")
+            answers = time_wrapper(reader.read)(question, reader_input, 5)
+
+            # Calculate softmaxed scores for readable output
+            sm = torch.nn.Softmax(dim=0)
+            document_scores = sm(torch.Tensor(
+                [pred.relevance_score for pred in answers]))
+            span_scores = sm(torch.Tensor(
+                [pred.span_score for pred in answers]))
+
+            print_answers(answers, scores, context)
+
+            # TODO evaluation and storing of results
+
+        times = get_times()
+        print(times)
+        # TODO evaluation and storing of results
+
+    # # Initialize retriever
+    # retriever = FaissRetriever(paragraphs)
+    # # retriever = ESRetriever(paragraphs)
+
+    # # Retrieve example
+    # # random.seed(111)
+    # random_index = random.randint(0, len(questions_test["question"])-1)
+    # example_q = questions_test["question"][random_index]
+    # example_a = questions_test["answer"][random_index]
+
+    # scores, result = retriever.retrieve(example_q)
+    # reader_input = context_to_reader_input(result)
+
+    # # TODO: use new code from query.py to clean this up
+    # # Initialize reader
+    # answers = reader.read(example_q, reader_input)
+
+    # # Calculate softmaxed scores for readable output
+    # sm = torch.nn.Softmax(dim=0)
+    # document_scores = sm(torch.Tensor(
+    #     [pred.relevance_score for pred in answers]))
+    # span_scores = sm(torch.Tensor(
+    #     [pred.span_score for pred in answers]))
+
+    # print(example_q)
+    # for answer_i, answer in enumerate(answers):
+    #     print(f"[{answer_i + 1}]: {answer.text}")
+    #     print(f"\tDocument {answer.doc_id}", end='')
+    #     print(f"\t(score {document_scores[answer_i] * 100:.02f})")
+    #     print(f"\tSpan {answer.start_index}-{answer.end_index}", end='')
+    #     print(f"\t(score {span_scores[answer_i] * 100:.02f})")
+    #     print()  # Newline
+
+    # # print(f"Example q: {example_q} answer: {result['text'][0]}")
+
+    # # for i, score in enumerate(scores):
+    # #     print(f"Result {i+1} (score: {score:.02f}):")
+    # #     print(result['text'][i])
+
+    # # Determine best answer we want to evaluate
+    # highest, highest_index = 0, 0
+    # for i, value in enumerate(span_scores):
+    #     if value + document_scores[i] > highest:
+    #         highest = value + document_scores[i]
+    #         highest_index = i
+
+    # # Retrieve exact match and F1-score
+    # exact_match, f1_score = evaluate(
+    #     example_a, answers[highest_index].text)
+    # print(f"Gold answer: {example_a}\n"
+    #       f"Predicted answer: {answers[highest_index].text}\n"
+    #       f"Exact match: {exact_match:.02f}\n"
+    #       f"F1-score: {f1_score:.02f}")

     # Calculate overall performance
     # total_f1 = 0
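A note on the timing pattern above: `timeit(name)` is called first to build a decorator whose recorded name depends on the current experiment, and the decorated function is then invoked immediately, as in `time_wrapper(reader.read)(question, reader_input, 5)`. Since `src/utils/timing.py` is not part of this commit, the sketch below is only an assumption about what a `timeit`/`get_times` pair compatible with that usage might look like, not the project's actual implementation.

```python
# Hypothetical sketch of a timing module compatible with main.py's usage;
# the real src/utils/timing.py is not shown in this commit.
import time
from collections import defaultdict
from functools import wraps
from typing import Callable, Dict, List

_times: Dict[str, List[float]] = defaultdict(list)


def timeit(name: str) -> Callable:
    """Return a decorator that records wall-clock durations under `name`."""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            result = func(*args, **kwargs)
            _times[name].append(time.perf_counter() - start)
            return result
        return wrapper
    return decorator


def get_times() -> Dict[str, List[float]]:
    """Return all recorded durations, keyed by the name passed to timeit()."""
    return dict(_times)
```

Written this way, the same helper works both as a static decorator (`@timeit("faissretriever.retrieve")`, as in the retriever files below) and applied dynamically at call time, which is the workaround main.py uses to get a per-experiment name.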
src/retrievers/base_retriever.py
CHANGED

@@ -1,3 +1,12 @@
+from typing import Dict, List, Tuple
+
+import numpy as np
+
+RetrieveTypeResult = Dict[str, List[str]]
+RetrieveTypeScores = np.ndarray
+RetrieveType = Tuple[RetrieveTypeScores, RetrieveTypeResult]
+
+
 class Retriever():
-    def retrieve(self, query: str, k: int):
-        [removed line; not legible in this view]
+    def retrieve(self, query: str, k: int) -> RetrieveType:
+        raise NotImplementedError()
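With these aliases, `Retriever.retrieve` now documents its contract: it returns a `(scores, result)` tuple in which `scores` is a NumPy array and `result` maps column names to lists of strings. A minimal, purely illustrative subclass (not part of the repository) that satisfies the annotation:

```python
# Hypothetical subclass showing the RetrieveType contract; not in the repo.
import numpy as np

from src.retrievers.base_retriever import RetrieveType, Retriever


class DummyRetriever(Retriever):
    """Toy retriever that returns the same paragraph k times."""

    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
        scores = np.ones(k, dtype=np.float32)      # RetrieveTypeScores
        result = {"text": ["lorem ipsum"] * k}     # RetrieveTypeResult
        return scores, result
```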
src/retrievers/es_retriever.py
CHANGED

@@ -1,8 +1,11 @@
+import os
+
 from datasets import DatasetDict
-from src.utils.log import get_logger
-from src.retrievers.base_retriever import Retriever
 from elasticsearch import Elasticsearch
-[removed line; not legible in this view]
+
+from src.retrievers.base_retriever import RetrieveType, Retriever
+from src.utils.log import get_logger
+from src.utils.timing import timeit

 logger = get_logger()

@@ -31,5 +34,6 @@ class ESRetriever(Retriever):
            es_index_name="paragraphs",
            es_client=self.client)

-    [removed line; not legible in this view; presumably the old `def retrieve` signature]
+    @timeit("esretriever.retrieve")
+    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
         return self.paragraphs.get_nearest_examples("paragraphs", query, k)
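The context lines above show the class relying on the Elasticsearch integration in Hugging Face `datasets`: `add_elasticsearch_index(...)` builds the index and `get_nearest_examples(index_name, query, k)` returns a `(scores, examples)` pair, which is the shape `RetrieveType` describes. A standalone sketch of that pattern follows; the column name, index name, and the locally running Elasticsearch server are assumptions for illustration, not taken from this commit.

```python
# Illustrative use of the datasets/Elasticsearch pattern ESRetriever builds on.
# Assumes an Elasticsearch server is reachable at localhost:9200 and that the
# dataset has a "text" column; both are assumptions, not facts from the diff.
from datasets import load_dataset
from elasticsearch import Elasticsearch

client = Elasticsearch("http://localhost:9200")
paragraphs = load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs", split="train")

# Build a full-text index over the chosen column.
paragraphs.add_elasticsearch_index(
    "text", es_index_name="paragraphs", es_client=client)

# Query it: returns (scores, examples), best match first.
scores, examples = paragraphs.get_nearest_examples(
    "paragraphs", "what is a language model", k=5)
print(scores)
print(examples["text"][0])
```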
src/retrievers/faiss_retriever.py
CHANGED

@@ -10,9 +10,10 @@ from transformers import (
     DPRQuestionEncoderTokenizer,
 )

-from src.retrievers.base_retriever import Retriever
+from src.retrievers.base_retriever import RetrieveType, Retriever
 from src.utils.log import get_logger
 from src.utils.preprocessing import remove_formulas
+from src.utils.timing import timeit

 # Hacky fix for FAISS error on macOS
 # See https://stackoverflow.com/a/63374568/4545692
@@ -83,7 +84,8 @@ class FaissRetriever(Retriever):

         return index

-    [removed line; not legible in this view; presumably the old `def retrieve` signature]
+    @timeit("faissretriever.retrieve")
+    def retrieve(self, query: str, k: int = 5) -> RetrieveType:
         def embed(q):
             # Inline helper function to perform embedding
             tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
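For the FAISS path, `retrieve` embeds the query with a DPR question encoder (the `embed` helper in the context lines) before searching the index. The full class is outside this hunk, so the sketch below only illustrates the general DPR-plus-FAISS pattern it appears to follow; the attribute names, model checkpoint, and index column are assumptions rather than the repository's code.

```python
# Sketch of the DPR question-encoder + FAISS retrieval pattern; names and the
# "embeddings" index column are assumptions, not taken from this commit.
import numpy as np
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

checkpoint = "facebook/dpr-question_encoder-single-nq-base"
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(checkpoint)
q_encoder = DPRQuestionEncoder.from_pretrained(checkpoint)


def embed(q: str) -> np.ndarray:
    """Embed a question into the DPR vector space used by the FAISS index."""
    tok = q_tokenizer(q, return_tensors="pt", truncation=True)
    with torch.no_grad():
        out = q_encoder(**tok)
    return out.pooler_output.numpy()

# Given a dataset with a FAISS index on an "embeddings" column, e.g.
#   dataset.add_faiss_index(column="embeddings")
# retrieval then looks like:
#   scores, examples = dataset.get_nearest_examples("embeddings", embed(query), k=5)
```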
test.py
ADDED

@@ -0,0 +1,20 @@
+# %%
+from datasets import load_dataset
+from src.retrievers.faiss_retriever import FaissRetriever
+
+
+data = load_dataset("GroNLP/ik-nlp-22_slp", "paragraphs")
+
+# # %%
+# x = data["test"][:3]
+
+# # %%
+# for y in x:
+
+#     print(y)
+# # %%
+# x.num_rows
+
+# # %%
+retriever = FaissRetriever(data)
+scores, result = retriever.retrieve("hello world")