# Initial CodeRAG deploy (commit d557d77)
"""Retrieval module for semantic search."""
from typing import Optional
from coderag.config import get_settings
from coderag.indexing.embeddings import EmbeddingGenerator
from coderag.indexing.vectorstore import VectorStore
from coderag.logging import get_logger
from coderag.models.chunk import Chunk
from coderag.models.response import RetrievedChunk
logger = get_logger(__name__)
class Retriever:
    """Retrieves relevant code chunks for a query via semantic search."""

    def __init__(
        self,
        vectorstore: Optional[VectorStore] = None,
        embedder: Optional[EmbeddingGenerator] = None,
    ) -> None:
        """Initialize the retriever.

        Args:
            vectorstore: Vector store to search; a fresh ``VectorStore`` is
                created when omitted.
            embedder: Embedding generator for queries; a fresh
                ``EmbeddingGenerator`` is created when omitted.
        """
        settings = get_settings()
        self.vectorstore = vectorstore or VectorStore()
        self.embedder = embedder or EmbeddingGenerator()
        # Retrieval limits and the score cutoff come from application settings.
        self.default_top_k = settings.retrieval.default_top_k
        self.max_top_k = settings.retrieval.max_top_k
        self.similarity_threshold = settings.retrieval.similarity_threshold

    def retrieve(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> list[RetrievedChunk]:
        """Return the chunks most relevant to ``query`` within a repository.

        Args:
            query: Natural-language or code query text.
            repo_id: Identifier of the repository to search.
            top_k: Maximum number of chunks; defaults to the configured value
                and is always capped at ``max_top_k``.
            similarity_threshold: Minimum relevance score; defaults to the
                configured threshold.

        Returns:
            Retrieved chunks in the order returned by the vector store.
        """
        # Use an explicit `is None` check (not truthiness) so a caller-supplied
        # top_k=0 is respected rather than silently replaced by the default.
        effective_top_k = self.default_top_k if top_k is None else top_k
        effective_top_k = min(effective_top_k, self.max_top_k)
        threshold = (
            self.similarity_threshold
            if similarity_threshold is None
            else similarity_threshold
        )
        logger.info(
            "Retrieving chunks",
            query=query[:100],  # truncate so log lines stay bounded
            repo_id=repo_id,
            top_k=effective_top_k,
        )

        # Embed the query (query-mode embedding, not document-mode).
        query_embedding = self.embedder.generate_embedding(query, is_query=True)

        # Nearest-neighbor search scoped to the repository.
        results = self.vectorstore.query(
            query_embedding=query_embedding,
            repo_id=repo_id,
            top_k=effective_top_k,
            similarity_threshold=threshold,
        )

        # Convert (chunk, score) pairs into API-facing RetrievedChunk models.
        retrieved_chunks = [
            RetrievedChunk(
                chunk_id=chunk.id,
                content=chunk.content,
                file_path=chunk.file_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                relevance_score=score,
                chunk_type=chunk.chunk_type.value,
                name=chunk.name,
            )
            for chunk, score in results
        ]
        logger.info("Chunks retrieved", count=len(retrieved_chunks))
        return retrieved_chunks

    def retrieve_with_context(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> tuple[list[RetrievedChunk], str]:
        """Retrieve chunks and format them into an LLM-ready context string.

        Args:
            query: Natural-language or code query text.
            repo_id: Identifier of the repository to search.
            top_k: Maximum number of chunks (see :meth:`retrieve`).
            similarity_threshold: Minimum relevance score, forwarded to
                :meth:`retrieve`. New optional parameter; omitting it
                preserves the previous behavior.

        Returns:
            ``(chunks, context)`` where ``context`` is a numbered,
            citation-annotated listing of each chunk's code, or a fallback
            message when nothing was retrieved.
        """
        chunks = self.retrieve(
            query, repo_id, top_k, similarity_threshold=similarity_threshold
        )

        # Number each chunk so the LLM can cite sources as [1], [2], ...
        context_parts = [
            f"[{i}] {chunk.citation}\n"
            f"Type: {chunk.chunk_type}"
            f"{f' | Name: {chunk.name}' if chunk.name else ''}\n"
            f"```\n{chunk.content}\n```\n"
            for i, chunk in enumerate(chunks, 1)
        ]
        context = "\n".join(context_parts) if context_parts else "No relevant code found."
        return chunks, context