Commit · 1fb8ae3
1 Parent(s): 2a1de95
decouple ds loading from retriever
Browse files
- main.py +5 -2
- src/retrievers/es_retriever.py +5 -1
- src/retrievers/faiss_retriever.py +19 -50
main.py
CHANGED
|
@@ -27,11 +27,14 @@ if __name__ == '__main__':
|
|
| 27 |
|
| 28 |
# logger.info(questions)
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
# Initialize retriever
|
| 31 |
-
retriever = FaissRetriever()
|
| 32 |
|
| 33 |
# Retrieve example
|
| 34 |
-
#random.seed(111)
|
| 35 |
random_index = random.randint(0, len(questions_test["question"])-1)
|
| 36 |
example_q = questions_test["question"][random_index]
|
| 37 |
example_a = questions_test["answer"][random_index]
|
|
|
|
| 27 |
|
| 28 |
# logger.info(questions)
|
| 29 |
|
| 30 |
+
dataset_paragraphs = cast(DatasetDict, load_dataset(
|
| 31 |
+
"GroNLP/ik-nlp-22_slp", "paragraphs"))
|
| 32 |
+
|
| 33 |
# Initialize retriever
|
| 34 |
+
retriever = FaissRetriever(dataset_paragraphs)
|
| 35 |
|
| 36 |
# Retrieve example
|
| 37 |
+
# random.seed(111)
|
| 38 |
random_index = random.randint(0, len(questions_test["question"])-1)
|
| 39 |
example_q = questions_test["question"][random_index]
|
| 40 |
example_a = questions_test["answer"][random_index]
|
src/retrievers/es_retriever.py
CHANGED
|
@@ -1,10 +1,14 @@
|
|
|
|
|
| 1 |
from src.utils.log import get_logger
|
|
|
|
|
|
|
| 2 |
|
| 3 |
logger = get_logger()
|
| 4 |
|
| 5 |
|
| 6 |
class ESRetriever(Retriever):
|
| 7 |
-
def __init__(self, data_set):
|
|
|
|
| 8 |
pass
|
| 9 |
|
| 10 |
def retrieve(self, query: str, k: int):
|
|
|
|
| 1 |
+
from datasets import load_dataset
|
| 2 |
from src.utils.log import get_logger
|
| 3 |
+
from src.retrievers.base_retriever import Retriever
|
| 4 |
+
|
| 5 |
|
| 6 |
logger = get_logger()
|
| 7 |
|
| 8 |
|
| 9 |
class ESRetriever(Retriever):
|
| 10 |
+
def __init__(self, data_set: ) -> None:
|
| 11 |
+
|
| 12 |
pass
|
| 13 |
|
| 14 |
def retrieve(self, query: str, k: int):
|
src/retrievers/faiss_retriever.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import os.path
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
from datasets import load_dataset
|
| 6 |
from transformers import (
|
| 7 |
DPRContextEncoder,
|
| 8 |
DPRContextEncoderTokenizer,
|
|
@@ -26,14 +26,7 @@ class FaissRetriever(Retriever):
|
|
| 26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 27 |
"""
|
| 28 |
|
| 29 |
-
def __init__(self,
|
| 30 |
-
"""Initialize the retriever
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
dataset (str, optional): The dataset to train on. Assumes the
|
| 34 |
-
information is stored in a column named 'text'. Defaults to
|
| 35 |
-
"GroNLP/ik-nlp-22_slp".
|
| 36 |
-
"""
|
| 37 |
torch.set_grad_enabled(False)
|
| 38 |
|
| 39 |
# Context encoding and tokenization
|
|
@@ -52,36 +45,22 @@ class FaissRetriever(Retriever):
|
|
| 52 |
"facebook/dpr-question_encoder-single-nq-base"
|
| 53 |
)
|
| 54 |
|
| 55 |
-
|
| 56 |
-
self.
|
| 57 |
-
|
|
|
|
| 58 |
|
| 59 |
-
def
|
| 60 |
self,
|
| 61 |
-
dataset_name: str,
|
| 62 |
-
embedding_path: str = "./src/models/paragraphs_embedding.faiss",
|
| 63 |
force_new_embedding: bool = False):
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
Returns:
|
| 72 |
-
Dataset: A dataset with a new column 'embeddings' containing FAISS
|
| 73 |
-
embeddings.
|
| 74 |
-
"""
|
| 75 |
-
# Load dataset
|
| 76 |
-
ds = load_dataset(dataset_name, name="paragraphs")[
|
| 77 |
-
"train"] # type: ignore
|
| 78 |
-
|
| 79 |
-
if not force_new_embedding and os.path.exists(embedding_path):
|
| 80 |
-
# If we already have FAISS embeddings, load them from disk
|
| 81 |
-
ds.load_faiss_index('embeddings', embedding_path) # type: ignore
|
| 82 |
return ds
|
| 83 |
else:
|
| 84 |
-
# If there are no FAISS embeddings, generate them
|
| 85 |
def embed(row):
|
| 86 |
# Inline helper function to perform embedding
|
| 87 |
p = row["text"]
|
|
@@ -91,35 +70,25 @@ class FaissRetriever(Retriever):
|
|
| 91 |
return {"embeddings": enc}
|
| 92 |
|
| 93 |
# Add FAISS embeddings
|
| 94 |
-
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
# save dataset w/ embeddings
|
| 99 |
os.makedirs("./src/models/", exist_ok=True)
|
| 100 |
-
|
|
|
|
| 101 |
|
| 102 |
-
return
|
| 103 |
|
| 104 |
def retrieve(self, query: str, k: int = 5):
|
| 105 |
-
"""Retrieve the top k matches for a search query.
|
| 106 |
-
|
| 107 |
-
Args:
|
| 108 |
-
query (str): A search query
|
| 109 |
-
k (int, optional): The number of documents to retrieve. Defaults to
|
| 110 |
-
5.
|
| 111 |
-
|
| 112 |
-
Returns:
|
| 113 |
-
tuple: A tuple of lists of scores and results.
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
def embed(q):
|
| 117 |
# Inline helper function to perform embedding
|
| 118 |
tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
|
| 119 |
return self.q_encoder(**tok)[0][0].numpy()
|
| 120 |
|
| 121 |
question_embedding = embed(query)
|
| 122 |
-
scores, results = self.
|
| 123 |
"embeddings", question_embedding, k=k
|
| 124 |
)
|
| 125 |
|
|
|
|
| 2 |
import os.path
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
from datasets import DatasetDict, load_dataset
|
| 6 |
from transformers import (
|
| 7 |
DPRContextEncoder,
|
| 8 |
DPRContextEncoderTokenizer,
|
|
|
|
| 26 |
based on https://huggingface.co/docs/datasets/faiss_es#faiss.
|
| 27 |
"""
|
| 28 |
|
| 29 |
+
def __init__(self, dataset: DatasetDict, embedding_path: str = "./src/models/paragraphs_embedding.faiss") -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
torch.set_grad_enabled(False)
|
| 31 |
|
| 32 |
# Context encoding and tokenization
|
|
|
|
| 45 |
"facebook/dpr-question_encoder-single-nq-base"
|
| 46 |
)
|
| 47 |
|
| 48 |
+
self.dataset = dataset
|
| 49 |
+
self.embedding_path = embedding_path
|
| 50 |
+
|
| 51 |
+
self.index = self._init_index()
|
| 52 |
|
| 53 |
+
def _init_index(
|
| 54 |
self,
|
|
|
|
|
|
|
| 55 |
force_new_embedding: bool = False):
|
| 56 |
+
|
| 57 |
+
ds = self.dataset["train"]
|
| 58 |
+
|
| 59 |
+
if not force_new_embedding and os.path.exists(self.embedding_path):
|
| 60 |
+
ds.load_faiss_index(
|
| 61 |
+
'embeddings', self.embedding_path) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
return ds
|
| 63 |
else:
|
|
|
|
| 64 |
def embed(row):
|
| 65 |
# Inline helper function to perform embedding
|
| 66 |
p = row["text"]
|
|
|
|
| 70 |
return {"embeddings": enc}
|
| 71 |
|
| 72 |
# Add FAISS embeddings
|
| 73 |
+
index = ds.map(embed) # type: ignore
|
| 74 |
|
| 75 |
+
index.add_faiss_index(column="embeddings")
|
| 76 |
|
| 77 |
# save dataset w/ embeddings
|
| 78 |
os.makedirs("./src/models/", exist_ok=True)
|
| 79 |
+
index.save_faiss_index(
|
| 80 |
+
"embeddings", self.embedding_path)
|
| 81 |
|
| 82 |
+
return index
|
| 83 |
|
| 84 |
def retrieve(self, query: str, k: int = 5):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
def embed(q):
|
| 86 |
# Inline helper function to perform embedding
|
| 87 |
tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
|
| 88 |
return self.q_encoder(**tok)[0][0].numpy()
|
| 89 |
|
| 90 |
question_embedding = embed(query)
|
| 91 |
+
scores, results = self.index.get_nearest_examples(
|
| 92 |
"embeddings", question_embedding, k=k
|
| 93 |
)
|
| 94 |
|