Ramon Meffert committed · Commit 83870cc · Parent: 8bbe3aa

Add base model retriever

Changed files:

- README.md (+48, -0)
- main.py → base_model/main.py (+7, -6)
- base_model/retriever.py (+53, -24)
- poetry.lock (+29, -1)
- pyproject.toml (+1, -0)
README.md
CHANGED

````diff
@@ -25,3 +25,51 @@ Most QA systems consist of two components:
 
 - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
 - Overview of open-domain question answering techniques: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
+
+## Base model
+
+So far this is only a retriever that fetches the top-k relevant documents
+based on a question. It does reach high similarity scores for many questions,
+but the documents it retrieves are usually not very relevant.
+
+```bash
+poetry shell
+cd base_model
+poetry run python main.py
+```
+
+### Example
+
+"What is the perplexity of a language model?"
+
+> Result 1 (score: 74.10):
+> Figure 10.17 A sample alignment between sentences in English and French, with
+> sentences extracted from Antoine de Saint-Exupery's Le Petit Prince and a
+> hypothetical translation. Sentence alignment takes sentences e1, ..., en,
+> and f1, ..., fn and finds minimal sets of sentences that are translations
+> of each other, including single sentence mappings like (e1, f1), (e4-f3),
+> (e5-f4), (e6-f6) as well as 2-1 alignments (e2/e3, f2), (e7/e8-f7), and
+> null alignments (f5).
+>
+> Result 2 (score: 74.23):
+> Character or word overlap-based metrics like chrF (or BLEU, or etc.) are
+> mainly used to compare two systems, with the goal of answering questions like:
+> did the new algorithm we just invented improve our MT system? To know if the
+> difference between the chrF scores of two MT systems is a significant
+> difference, we use the paired bootstrap test, or the similar randomization
+> test.
+>
+> Result 3 (score: 74.43):
+> The model thus predicts the class negative for the test sentence.
+>
+> Result 4 (score: 74.95):
+> Translating from languages with extensive pro-drop, like Chinese or Japanese,
+> to non-pro-drop languages like English can be difficult since the model must
+> somehow identify each zero and recover who or what is being talked about in
+> order to insert the proper pronoun.
+>
+> Result 5 (score: 76.22):
+> Similarly, a recent challenge set, the WinoMT dataset (Stanovsky et al., 2019)
+> shows that MT systems perform worse when they are asked to translate sentences
+> that describe people with non-stereotypical gender roles, like "The doctor
+> asked the nurse to help her in the operation".
````
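One way to read the example scores above: they increase from Result 1 (74.10) to Result 5 (76.22), and FAISS returns the nearest matches first. Assuming the index is built with `add_faiss_index`'s defaults (a flat L2-distance index, as far as I can tell; treat that as an assumption), the scores are distances rather than similarities, so lower means closer, while DPR embeddings are trained for inner-product similarity, where higher means more similar. A minimal sketch of the two readings, using random stand-in vectors (illustrative only, not real DPR embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.random(768, dtype=np.float32)  # stand-in question embedding
p = rng.random(768, dtype=np.float32)  # stand-in passage embedding

# What a flat L2 FAISS index reports: squared distance, lower = closer.
l2_sq = float(np.sum((q - p) ** 2))

# What DPR is trained to maximize: inner product, higher = more similar.
ip = float(q @ p)

print(f"squared L2 distance: {l2_sq:.2f}, inner product: {ip:.2f}")
```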
main.py → base_model/main.py
RENAMED

```diff
@@ -1,14 +1,15 @@
-from …
+from retriever import Retriever
+
 
 if __name__ == '__main__':
     # Initialize retriever
     r = Retriever()
 
     # Retrieve example
-    …
-    "…
+    scores, result = r.retrieve(
+        "What is the perplexity of a language model?")
 
-    for i, …
-        print(f"Result {i+1} (score: {score…
-        print(result['text'][…
+    for i, score in enumerate(scores):
+        print(f"Result {i+1} (score: {score:.02f}):")
+        print(result['text'][i])
         print()  # Newline
```
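A note on the `result['text'][i]` indexing above: `retrieve` returns the output of `datasets`' `get_nearest_examples` unchanged, i.e. an array of k scores plus a dict mapping each dataset column to a list of k field values. A sketch of an equivalent loop written with `zip`, assuming it is run from `base_model/` like `main.py`:

```python
from retriever import Retriever

r = Retriever()
scores, result = r.retrieve("What is the perplexity of a language model?")

# result is a dict of columns; result["text"] holds the k retrieved
# passages, aligned index-by-index with the scores array.
for rank, (score, passage) in enumerate(zip(scores, result["text"]), start=1):
    print(f"Result {rank} (score: {score:.02f}):")
    print(passage)
    print()
```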
base_model/retriever.py
CHANGED

```diff
@@ -1,10 +1,21 @@
-from transformers import …
-…
+from transformers import (
+    DPRContextEncoder,
+    DPRContextEncoderTokenizer,
+    DPRQuestionEncoder,
+    DPRQuestionEncoderTokenizer,
+)
 from datasets import load_dataset
 import torch
+import os.path
 
+# Hacky fix for FAISS error on macOS
+# See https://stackoverflow.com/a/63374568/4545692
+import os
 
-class Retriever():
+os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
+
+
+class Retriever:
     """A class used to retrieve relevant documents based on some query.
     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
     """
@@ -21,47 +32,64 @@ class Retriever():
 
         # Context encoding and tokenization
         self.ctx_encoder = DPRContextEncoder.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"…
+            "facebook/dpr-ctx_encoder-single-nq-base"
+        )
         self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
-            "facebook/dpr-ctx_encoder-single-nq-base"…
+            "facebook/dpr-ctx_encoder-single-nq-base"
+        )
 
         # Question encoding and tokenization
         self.q_encoder = DPRQuestionEncoder.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"…
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
         self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
-            "facebook/dpr-question_encoder-single-nq-base"…
+            "facebook/dpr-question_encoder-single-nq-base"
+        )
 
         # Dataset building
         self.dataset = self.__init_dataset(dataset)
 
-    def __init_dataset(self, …
+    def __init_dataset(self,
+                       dataset: str,
+                       fname: str = "./models/paragraphs_embedding.faiss"):
         """Loads the dataset and adds FAISS embeddings.
 
         Args:
             dataset (str): A HuggingFace dataset name.
+            fname (str): The name to use to save the embeddings to disk for
+                faster loading after the first run.
 
         Returns:
             Dataset: A dataset with a new column 'embeddings' containing FAISS
                 embeddings.
         """
-        # TODO: save ds w/ embeddings to disk and retrieve it if it already exists
-
         # Load dataset
-        ds = load_dataset(dataset, name=…
-        …
-        #…
-        …
+        ds = load_dataset(dataset, name="paragraphs")["train"]
 
-        …
-        ds_with_embeddings.add_faiss_index(column='embeddings')
-        return ds_with_embeddings
+        if os.path.exists(fname):
+            # If we already have FAISS embeddings, load them from disk
+            ds.load_faiss_index('embeddings', fname)
+            return ds
+        else:
+            # If there are no FAISS embeddings, generate them
+            def embed(row):
+                # Inline helper function to perform embedding
+                p = row["text"]
+                tok = self.ctx_tokenizer(
+                    p, return_tensors="pt", truncation=True)
+                enc = self.ctx_encoder(**tok)[0][0].numpy()
+                return {"embeddings": enc}
+
+            # Add FAISS embeddings
+            ds_with_embeddings = ds.map(embed)
+
+            ds_with_embeddings.add_faiss_index(column="embeddings")
 
+            # Save dataset w/ embeddings
+            os.makedirs("./models/", exist_ok=True)
+            ds_with_embeddings.save_faiss_index("embeddings", fname)
 
+            return ds_with_embeddings
 
     def retrieve(self, query: str, k: int = 5):
         """Retrieve the top k matches for a search query.
@@ -77,10 +105,11 @@ class Retriever():
 
         def embed(q):
             # Inline helper function to perform embedding
-            tok = self.q_tokenizer(q, return_tensors=…
+            tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
             return self.q_encoder(**tok)[0][0].numpy()
 
         question_embedding = embed(query)
         scores, results = self.dataset.get_nearest_examples(
-            …
+            "embeddings", question_embedding, k=k
+        )
         return scores, results
```
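Related to the scoring caveat noted under the README example: `add_faiss_index(column="embeddings")` with no further arguments builds an L2-distance index, while DPR encoders are trained with an inner-product objective. `datasets` exposes a `metric_type` parameter that is forwarded to FAISS, so one possible refinement (not part of this commit) is to index with inner-product scoring instead. A minimal self-contained sketch with a hypothetical two-passage dataset standing in for the real paragraphs set:

```python
import faiss
import numpy as np
from datasets import Dataset

# Hypothetical stand-in dataset; the real one comes from load_dataset.
ds = Dataset.from_dict({
    "text": ["passage one", "passage two"],
    "embeddings": np.random.rand(2, 768).astype("float32").tolist(),
})

# Build the index with inner-product scoring, matching the objective DPR
# embeddings were trained with: larger scores then mean more similar.
ds.add_faiss_index(column="embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)

query = np.random.rand(768).astype("float32")
scores, results = ds.get_nearest_examples("embeddings", query, k=2)
print(scores, results["text"])  # best match (largest inner product) first
```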
poetry.lock
CHANGED

```diff
@@ -51,6 +51,18 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
 tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
 tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
 
+[[package]]
+name = "autopep8"
+version = "1.6.0"
+description = "A tool that automatically formats Python code to conform to the PEP 8 style guide"
+category = "dev"
+optional = false
+python-versions = "*"
+
+[package.dependencies]
+pycodestyle = ">=2.8.0"
+toml = "*"
+
 [[package]]
 name = "certifi"
 version = "2021.10.8"
@@ -460,6 +472,14 @@ python-versions = "*"
 docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
 testing = ["pytest", "requests", "numpy", "datasets"]
 
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+category = "dev"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+
 [[package]]
 name = "torch"
 version = "1.11.0"
@@ -590,7 +610,7 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "…
+content-hash = "227b922ee14abf36ca75bb238d239d712bed9213d54c567996566d465e465733"
 
 [metadata.files]
 aiohttp = [
@@ -679,6 +699,10 @@ attrs = [
     {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
     {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
 ]
+autopep8 = [
+    {file = "autopep8-1.6.0-py2.py3-none-any.whl", hash = "sha256:ed77137193bbac52d029a52c59bec1b0629b5a186c495f1eb21b126ac466083f"},
+    {file = "autopep8-1.6.0.tar.gz", hash = "sha256:44f0932855039d2c15c4510d6df665e4730f2b8582704fa48f9c55bd3e17d979"},
+]
 certifi = [
     {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"},
     {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"},
@@ -1161,6 +1185,10 @@ tokenizers = [
     {file = "tokenizers-0.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b28966c68a2cdecd5120f4becea159eebe0335b8202e21e292eb381031026edc"},
     {file = "tokenizers-0.11.6.tar.gz", hash = "sha256:562b2022faf0882586c915385620d1f11798fc1b32bac55353a530132369a6d0"},
 ]
+toml = [
+    {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+    {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
 torch = [
     {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
     {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
```
pyproject.toml
CHANGED

```diff
@@ -14,6 +14,7 @@ faiss-cpu = "^1.7.2"
 
 [tool.poetry.dev-dependencies]
 flake8 = "^4.0.1"
+autopep8 = "^1.6.0"
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
```