#!/usr/bin/env python3
"""
Vector Store Compatibility Wrapper
VedaMD Medical RAG - Compatibility Fix

This module provides a compatibility wrapper for the embedding-dimension
mismatch between Clinical ModernBERT embeddings (768d) and the existing
vector store (384d).

The wrapper does NOT convert embeddings between dimensions. Instead it:
- retrieves candidates from the existing 384d store, using the same embedding
  model the index was built with (all-MiniLM-L6-v2), and
- re-ranks those candidates by blending their original score with a
  Clinical ModernBERT cosine similarity computed on the fly.

TEMPORARY SOLUTION:
- Allows testing of the enhanced medical RAG pipeline
- Keeps the existing 384d index usable without a rebuild
- Adds medical-domain signal through re-ranking where possible

FUTURE: Rebuild the vector store with Clinical ModernBERT to get the full
benefit of 768d retrieval.
"""

import logging
from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer

from simple_vector_store import SimpleVectorStore, SearchResult


class CompatibleMedicalVectorStore:
    """
    Compatibility wrapper for vector store dimension mismatches.

    Retrieval goes through the original 384d ``SimpleVectorStore``. The
    returned candidates are then re-ranked with Clinical ModernBERT
    cosine similarity; see ``_enhance_with_clinical_similarity``.
    """

    def __init__(
        self,
        repo_id: str = "sniro23/VedaMD-Vector-Store",
        clinical_model_name: str = "Simonlee711/Clinical_ModernBERT",
        clinical_weight: float = 0.4,
    ):
        """
        Load the original vector store and the clinical re-ranking embedder.

        Args:
            repo_id: Repository id forwarded to ``SimpleVectorStore``.
            clinical_model_name: SentenceTransformer model used for
                re-ranking. Defaults to Clinical ModernBERT.
            clinical_weight: Weight ``w`` of the clinical similarity in the
                blended score ``(1 - w) * original + w * clinical``. The
                default of 0.4 reproduces the previous fixed 0.6/0.4 split.

        Raises:
            ValueError: If ``clinical_weight`` is not within ``[0, 1]``.
        """
        # The negated chained comparison also rejects NaN.
        if not 0.0 <= clinical_weight <= 1.0:
            raise ValueError(
                f"clinical_weight must be in [0, 1], got {clinical_weight}"
            )
        self.clinical_weight = clinical_weight

        self.setup_logging()

        # Initialize both embedding models for compatibility
        self.logger.info("🔧 Initializing Vector Store Compatibility Layer...")

        # The existing index holds 384d vectors, so queries against it must
        # be embedded with the same model the index was built with.
        self.original_vector_store = SimpleVectorStore(
            repo_id=repo_id,
            embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # Match original
        )
        self.logger.info("✅ Original vector store loaded (384d)")

        # Clinical model is used only for re-ranking, never for index lookup.
        self.clinical_embedder = SentenceTransformer(clinical_model_name)
        self.logger.info("✅ Clinical ModernBERT loaded (768d)")

        self.logger.info("🎯 Vector Store Compatibility Layer ready")

    def setup_logging(self):
        """
        Configure root logging at INFO level and bind ``self.logger``.

        ``logging.basicConfig`` has no effect if the root logger already has
        handlers, so an application's own logging setup takes precedence.
        """
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        )
        self.logger = logging.getLogger(__name__)

    def search(self, query: str, k: int = 5) -> List[SearchResult]:
        """
        Retrieve ``k`` candidates from the original store and re-rank them.

        Args:
            query: Free-text search query.
            k: Number of candidates requested from the original store.

        Returns:
            The candidates re-ranked by blended score. If the original store
            returned nothing, its (empty) result is returned unchanged. If
            re-ranking fails, the original store's results are returned.
        """
        self.logger.info("🔍 Searching with compatibility layer: %s...", query[:50])

        # Use the original vector store for retrieval (384d compatibility).
        results = self.original_vector_store.search(query=query, k=k)
        if not results:
            return results

        enhanced_results = self._enhance_with_clinical_similarity(query, results)
        self.logger.info(
            "✅ Retrieved %d documents with medical enhancement",
            len(enhanced_results),
        )
        return enhanced_results

    def _enhance_with_clinical_similarity(
        self, query: str, results: List[SearchResult]
    ) -> List[SearchResult]:
        """
        Re-rank ``results`` using Clinical ModernBERT cosine similarity.

        Each result's score becomes
        ``(1 - clinical_weight) * result.score + clinical_weight * cosine``.
        NOTE(review): this assumes ``result.score`` is a similarity where
        higher is better, because results are sorted descending. Confirm
        against ``SimpleVectorStore``.

        Args:
            query: The search query.
            results: Non-empty candidates from the original vector store.

        Returns:
            New ``SearchResult`` objects sorted by blended score, highest
            first. Each carries ``original_score``, ``clinical_similarity``
            and ``enhanced_score`` in its metadata. On any failure, a warning
            is logged and ``results`` is returned unchanged, so re-ranking
            stays best-effort.
        """
        weight = self.clinical_weight
        try:
            # Embed the query and all documents in one batch rather than
            # making one encode() call per document.
            query_embedding = np.asarray(self.clinical_embedder.encode([query]))[0]
            doc_embeddings = np.asarray(
                self.clinical_embedder.encode([result.content for result in results])
            )

            query_norm = np.linalg.norm(query_embedding)
            doc_norms = np.linalg.norm(doc_embeddings, axis=1)

            enhanced_results = []
            for result, doc_embedding, doc_norm in zip(results, doc_embeddings, doc_norms):
                denominator = query_norm * doc_norm
                # A zero-norm embedding would yield 0/0 = NaN, which makes the
                # sort below meaningless; treat it as "no clinical similarity".
                if denominator > 0:
                    clinical_similarity = float(
                        np.dot(query_embedding, doc_embedding) / denominator
                    )
                else:
                    clinical_similarity = 0.0

                enhanced_score = float(
                    (1.0 - weight) * result.score + weight * clinical_similarity
                )

                enhanced_results.append(
                    SearchResult(
                        content=result.content,
                        score=enhanced_score,
                        metadata={
                            # Guard against a result without metadata (None).
                            **(result.metadata or {}),
                            "original_score": result.score,
                            "clinical_similarity": clinical_similarity,
                            "enhanced_score": enhanced_score,
                        },
                    )
                )

            enhanced_results.sort(key=lambda r: r.score, reverse=True)
            return enhanced_results

        except Exception as e:  # Deliberate best-effort fallback.
            self.logger.warning(
                "Clinical enhancement failed: %s. Using original results.", e
            )
            return results


def test_compatible_vector_store():
    """
    Manual smoke test: build the store and print re-ranked results for a
    few sample medical queries.

    This instantiates ``CompatibleMedicalVectorStore``, which loads both
    embedding models, so it is slow and needs those models to be available.
    """
    print("🧪 Testing Compatible Vector Store")

    store = CompatibleMedicalVectorStore()

    # Test medical queries
    test_queries = [
        "preeclampsia management protocol",
        "postpartum hemorrhage treatment",
        "contraindicated medications pregnancy",
    ]

    for query in test_queries:
        print(f"\n🔍 Query: {query}")
        results = store.search(query, k=3)

        if not results:
            print("   No results returned")
            continue

        for rank, result in enumerate(results, 1):
            print(f"   {rank}. Score: {result.score:.3f}")
            metadata = result.metadata or {}
            if "clinical_similarity" in metadata:
                print(f"      Clinical Similarity: {metadata['clinical_similarity']:.3f}")
            print(f"      Content: {result.content[:100]}...")

    print("\n✅ Compatible Vector Store Test Completed")


if __name__ == "__main__":
    test_compatible_vector_store()