from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from config.rag_config import RAGConfig
from src.embedder import Embedder
from src.retriever import Retriever

class RAGPipeline:
    """Retrieval-augmented generation: embed the query, retrieve the most
    relevant documents, and condition a seq2seq LLM on the retrieved context."""

    def __init__(self, config: RAGConfig, docs, doc_embeddings):
        self.config = config
        self.embedder = Embedder(config)
        self.retriever = Retriever(doc_embeddings, docs, config)
        # Retrieval is handled by the custom Retriever above, so generation uses a
        # plain seq2seq checkpoint instead of RagSequenceForGeneration, whose
        # generate() requires its own RagRetriever (or precomputed context tensors).
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.llm_model_name)

    def ask(self, query):
        """Answer a query by retrieving supporting documents and generating a response."""
        # Embed the query and fetch the top-matching documents from the store.
        query_emb = self.embedder.embed_texts([query])[0]
        retrieved = self.retriever.retrieve(query_emb)

        # Build a single prompt that combines the question with the retrieved context.
        context = "\n".join(r[0] for r in retrieved)
        input_text = f"Question: {query}\nContext: {context}"

        # Generate an answer conditioned on the prompt; truncate overly long
        # contexts to the model's maximum input length.
        inputs = self.tokenizer(input_text, return_tensors="pt", truncation=True)
        output = self.model.generate(**inputs, **self.config.generation_kwargs)
        answer = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        return answer, retrieved
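
# --- Usage sketch (illustrative only) ---
# Assumes RAGConfig can be constructed with its defaults, exposes `llm_model_name`
# and `generation_kwargs`, and that Embedder.embed_texts returns one embedding per
# input text; adjust to the actual definitions in config/ and src/.
if __name__ == "__main__":
    docs = [
        "The Eiffel Tower is located in Paris, France.",
        "The Colosseum is an ancient amphitheatre in Rome, Italy.",
    ]
    config = RAGConfig()

    # Pre-compute the document embeddings once and hand them to the pipeline.
    doc_embeddings = Embedder(config).embed_texts(docs)
    pipeline = RAGPipeline(config, docs, doc_embeddings)

    answer, retrieved = pipeline.ask("In which city is the Eiffel Tower?")
    print("Answer:", answer)
    print("Retrieved context:", retrieved)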