from transformers import RagTokenizer, RagSequenceForGeneration

from config.rag_config import RAGConfig
from src.embedder import Embedder
from src.retriever import Retriever


class RAGPipeline:
    """Retrieval-augmented generation pipeline.

    Embeds an incoming query, retrieves the most relevant documents, and
    feeds "Question + Context" to a seq2seq language model to generate an
    answer.

    NOTE(review): `RagTokenizer` / `RagSequenceForGeneration` are the
    Hugging Face RAG-specific classes, which normally perform their own
    retrieval via an attached `RagRetriever`. Here they are driven like a
    plain tokenizer/seq2seq model with retrieval done externally — confirm
    that `config.llm_model_name` points to a RAG checkpoint and that this
    usage is intentional.
    """

    def __init__(self, config: RAGConfig, docs, doc_embeddings) -> None:
        """Build the pipeline components.

        Args:
            config: Project RAG configuration; must provide
                ``llm_model_name`` and ``generation_kwargs``.
            docs: The document texts available for retrieval.
            doc_embeddings: Precomputed embeddings aligned with ``docs``.
        """
        self.config = config
        self.embedder = Embedder(config)
        self.retriever = Retriever(doc_embeddings, docs, config)
        self.tokenizer = RagTokenizer.from_pretrained(config.llm_model_name)
        self.model = RagSequenceForGeneration.from_pretrained(config.llm_model_name)

    def ask(self, query: str):
        """Answer ``query`` using retrieved context.

        Returns:
            A tuple ``(answer, retrieved)`` where ``answer`` is the decoded
            model output string and ``retrieved`` is the raw retriever
            result that was used to build the prompt.
        """
        # Embed the single query; embed_texts returns one embedding per input.
        query_emb = self.embedder.embed_texts([query])[0]
        retrieved = self.retriever.retrieve(query_emb)
        # Assumes each retrieved item is a (text, ...) tuple — text first.
        # TODO(review): confirm against Retriever.retrieve's return shape.
        context = "\n".join(r[0] for r in retrieved)
        input_text = f"Question: {query}\nContext: {context}"
        inputs = self.tokenizer(input_text, return_tensors="pt")
        # generation_kwargs (e.g. max_length, num_beams) come from config.
        output = self.model.generate(
            **inputs,
            **self.config.generation_kwargs,
        )
        # batch_decode returns a list (batch of 1); take the single answer.
        return self.tokenizer.batch_decode(output, skip_special_tokens=True)[0], retrieved