Spaces:
Sleeping
Sleeping
| """Retrieval module for semantic search.""" | |
| from typing import Optional | |
| from coderag.config import get_settings | |
| from coderag.indexing.embeddings import EmbeddingGenerator | |
| from coderag.indexing.vectorstore import VectorStore | |
| from coderag.logging import get_logger | |
| from coderag.models.chunk import Chunk | |
| from coderag.models.response import RetrievedChunk | |
| logger = get_logger(__name__) | |
class Retriever:
    """Retrieves relevant chunks for a query.

    A query is embedded with the configured embedder, the vector store is
    searched within a single repository, and every hit is converted into a
    ``RetrievedChunk``. Default limits (``default_top_k``, ``max_top_k``,
    ``similarity_threshold``) are read from ``get_settings().retrieval``.
    """

    def __init__(
        self,
        vectorstore: Optional[VectorStore] = None,
        embedder: Optional[EmbeddingGenerator] = None,
    ) -> None:
        """Create a retriever.

        Args:
            vectorstore: Store to search. A new ``VectorStore()`` is created
                when omitted.
            embedder: Embedding generator for queries. A new
                ``EmbeddingGenerator()`` is created when omitted.
        """
        settings = get_settings()
        # Compare against None explicitly: an injected collaborator that
        # happens to evaluate as falsy (e.g. a container-like store that is
        # currently empty) must not be silently swapped for a fresh default.
        self.vectorstore = vectorstore if vectorstore is not None else VectorStore()
        self.embedder = embedder if embedder is not None else EmbeddingGenerator()
        self.default_top_k = settings.retrieval.default_top_k
        self.max_top_k = settings.retrieval.max_top_k
        self.similarity_threshold = settings.retrieval.similarity_threshold

    def _resolve_top_k(self, top_k: Optional[int]) -> int:
        """Return the effective result limit for a request.

        ``None``, zero and negative values fall back to ``default_top_k``;
        the result is never larger than ``max_top_k``.
        """
        if top_k is None or top_k <= 0:
            top_k = self.default_top_k
        return min(top_k, self.max_top_k)

    def retrieve(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> list[RetrievedChunk]:
        """Return the chunks of ``repo_id`` that best match ``query``.

        Args:
            query: Search text; it is embedded with ``is_query=True``.
            repo_id: Repository the vector-store search is restricted to.
            top_k: Maximum number of results. Missing or non-positive values
                use ``default_top_k``; the value is capped at ``max_top_k``.
            similarity_threshold: Threshold passed to the vector store.
                ``None`` uses the configured default; ``0.0`` is honoured.

        Returns:
            One ``RetrievedChunk`` per ``(chunk, score)`` pair yielded by the
            vector store, in the order the store returned them.
        """
        top_k = self._resolve_top_k(top_k)
        threshold = (
            similarity_threshold
            if similarity_threshold is not None
            else self.similarity_threshold
        )
        # Only the first 100 characters of the query are logged.
        logger.info(
            "Retrieving chunks",
            query=query[:100],
            repo_id=repo_id,
            top_k=top_k,
        )

        query_embedding = self.embedder.generate_embedding(query, is_query=True)
        results = self.vectorstore.query(
            query_embedding=query_embedding,
            repo_id=repo_id,
            top_k=top_k,
            similarity_threshold=threshold,
        )

        retrieved_chunks = [
            RetrievedChunk(
                chunk_id=chunk.id,
                content=chunk.content,
                file_path=chunk.file_path,
                start_line=chunk.start_line,
                end_line=chunk.end_line,
                relevance_score=score,
                # chunk_type exposes ``.value`` (presumably an Enum member);
                # the response model receives that underlying value.
                chunk_type=chunk.chunk_type.value,
                name=chunk.name,
            )
            for chunk, score in results
        ]

        logger.info("Chunks retrieved", count=len(retrieved_chunks))
        return retrieved_chunks

    @staticmethod
    def _format_context(chunks) -> str:
        """Render chunks as a numbered, fenced context string for an LLM.

        Each entry holds a 1-based index with the chunk's citation, its type
        (plus its name when it has one) and the content inside a code fence.
        Entries are separated by a blank line. When ``chunks`` is empty the
        fixed text ``"No relevant code found."`` is returned.
        """
        context_parts = []
        for index, chunk in enumerate(chunks, 1):
            name_suffix = f" | Name: {chunk.name}" if chunk.name else ""
            context_parts.append(
                f"[{index}] {chunk.citation}\n"
                f"Type: {chunk.chunk_type}{name_suffix}\n"
                f"```\n{chunk.content}\n```\n"
            )
        return "\n".join(context_parts) if context_parts else "No relevant code found."

    def retrieve_with_context(
        self,
        query: str,
        repo_id: str,
        top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ) -> tuple[list[RetrievedChunk], str]:
        """Retrieve chunks and build the matching LLM context string.

        Args:
            query: Search text.
            repo_id: Repository to search in.
            top_k: Maximum number of results (see ``retrieve``).
            similarity_threshold: Optional threshold override forwarded to
                ``retrieve``; ``None`` keeps the configured default.

        Returns:
            A ``(chunks, context)`` tuple, where ``context`` is the formatted
            text for ``chunks`` or ``"No relevant code found."`` when no
            chunk was retrieved.
        """
        chunks = self.retrieve(
            query,
            repo_id,
            top_k=top_k,
            similarity_threshold=similarity_threshold,
        )
        return chunks, self._format_context(chunks)