Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Vector Store Compatibility Wrapper
VedaMD Medical RAG - Compatibility Fix

This module provides a compatibility wrapper that works around the dimension
mismatch between Clinical ModernBERT embeddings (768d) and the existing
vector store (384d).

TEMPORARY SOLUTION:
- Allows testing of the enhanced medical RAG pipeline
- Avoids any dimension conversion: retrieval uses the store's original 384d
  embedding model, and Clinical ModernBERT is used only to re-score the
  retrieved candidates
- Retains medical-domain benefits where possible

FUTURE: Rebuild the vector store with Clinical ModernBERT to get the full
benefit of 768d embeddings.
"""
| import numpy as np | |
| import logging | |
| from typing import List | |
| from sentence_transformers import SentenceTransformer | |
| from simple_vector_store import SimpleVectorStore, SearchResult | |
class CompatibleMedicalVectorStore:
    """
    Compatibility wrapper for vector store dimension mismatches.

    Candidates are retrieved from the existing vector store through the model
    it was built with (all-MiniLM-L6-v2, 384d), so the stored vectors are never
    compared against 768d Clinical ModernBERT embeddings. Clinical ModernBERT
    is used only to re-score the candidates that were already retrieved.
    """

    # Default weights for blending the store's own score with the clinical
    # cosine similarity; __init__ can override them per instance.
    original_weight: float = 0.6
    clinical_weight: float = 0.4

    def __init__(self, repo_id: str = "sniro23/VedaMD-Vector-Store", *,
                 original_weight: float = 0.6, clinical_weight: float = 0.4):
        """
        Load the original vector store and the Clinical ModernBERT embedder.

        Args:
            repo_id: Repository identifier passed to ``SimpleVectorStore``.
            original_weight: Weight of the vector store's own score in the
                blended ranking score.
            clinical_weight: Weight of the Clinical ModernBERT cosine
                similarity in the blended ranking score.
        """
        self.setup_logging()
        self.original_weight = original_weight
        self.clinical_weight = clinical_weight

        self.logger.info("π§ Initializing Vector Store Compatibility Layer...")

        # Retrieval must use the same model the stored (384d) vectors were built with.
        self.original_vector_store = SimpleVectorStore(
            repo_id=repo_id,
            embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"  # Match original
        )
        self.logger.info("β Original vector store loaded (384d)")

        # Clinical ModernBERT (768d) only re-scores hits; it never queries the index.
        self.clinical_embedder = SentenceTransformer("Simonlee711/Clinical_ModernBERT")
        self.logger.info("β Clinical ModernBERT loaded (768d)")

        self.logger.info("π― Vector Store Compatibility Layer ready")

    def setup_logging(self):
        """
        Configure logging and bind this module's logger to ``self.logger``.

        ``logging.basicConfig`` acts on the process-wide root logger and does
        nothing if the root logger already has handlers.
        """
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def search(self, query: str, k: int = 5) -> "List[SearchResult]":
        """
        Retrieve up to ``k`` documents for ``query`` and re-rank them clinically.

        Retrieval goes through the original (384d) vector store. The hits are
        then re-scored with Clinical ModernBERT; if re-scoring fails, the
        original hits are returned unchanged.

        Args:
            query: Free-text search query.
            k: Number of documents requested from the original store.

        Returns:
            The re-scored results sorted by blended score (highest first), or
            whatever the original store returned if it found nothing or
            re-scoring failed.
        """
        self.logger.info(f"π Searching with compatibility layer: {query[:50]}...")

        results = self.original_vector_store.search(query=query, k=k)
        if not results:
            return results

        enhanced_results = self._enhance_with_clinical_similarity(query, results)
        # The helper returns the very same list object when it falls back, so
        # only report enhancement when new results were actually produced.
        if enhanced_results is not results:
            self.logger.info(f"β Retrieved {len(enhanced_results)} documents with medical enhancement")
        return enhanced_results

    def _enhance_with_clinical_similarity(self, query: str,
                                          results: "List[SearchResult]") -> "List[SearchResult]":
        """
        Re-score ``results`` by blending each original score with the Clinical
        ModernBERT cosine similarity between ``query`` and the result content.

        The blended score is
        ``original_weight * result.score + clinical_weight * cosine``. The
        original score, the clinical similarity and the blended score are also
        recorded in each new result's metadata.

        Returns:
            New ``SearchResult`` objects sorted by blended score (highest
            first), or ``results`` itself if anything fails (best-effort).
        """
        try:
            query_embedding = self.clinical_embedder.encode([query])[0]
            # A single batched encode for all candidates instead of one model call per document.
            doc_embeddings = self.clinical_embedder.encode([result.content for result in results])

            enhanced_results = []
            for result, doc_embedding in zip(results, doc_embeddings):
                clinical_similarity = self._cosine_similarity(query_embedding, doc_embedding)
                # float() keeps the score a plain Python float, matching the metadata values.
                enhanced_score = float(
                    self.original_weight * result.score
                    + self.clinical_weight * clinical_similarity
                )
                enhanced_results.append(SearchResult(
                    content=result.content,
                    score=enhanced_score,
                    metadata={
                        # A result without metadata (None) is treated as having none,
                        # rather than aborting the whole re-rank.
                        **(result.metadata or {}),
                        'original_score': result.score,
                        'clinical_similarity': clinical_similarity,
                        'enhanced_score': enhanced_score,
                    },
                ))

            enhanced_results.sort(key=lambda r: r.score, reverse=True)
            return enhanced_results
        except Exception as e:
            # Deliberate best-effort fallback: clinical re-ranking is optional.
            self.logger.warning(f"Clinical enhancement failed: {e}. Using original results.")
            return results

    @staticmethod
    def _cosine_similarity(a, b) -> float:
        """
        Return the cosine similarity of two 1-D vectors as a Python float.

        Returns 0.0 when either vector has zero norm. Cosine similarity is
        undefined there, and the plain division would produce NaN, which
        breaks the descending score sort.
        """
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0:
            return 0.0
        return float(np.dot(a, b) / denominator)
def test_compatible_vector_store():
    """
    Manual smoke test for the compatibility wrapper.

    Builds a CompatibleMedicalVectorStore with its default repository and
    prints the top three hits for a few medical queries. Nothing is asserted;
    the output is meant to be read by a person.
    """
    print("π§ͺ Testing Compatible Vector Store")
    store = CompatibleMedicalVectorStore()

    sample_queries = (
        "preeclampsia management protocol",
        "postpartum hemorrhage treatment",
        "contraindicated medications pregnancy",
    )

    for query in sample_queries:
        print(f"\nπ Query: {query}")
        hits = store.search(query, k=3)
        for rank, hit in enumerate(hits, start=1):
            print(f" {rank}. Score: {hit.score:.3f}")
            # Re-scored hits carry this key; fallback hits may not.
            if 'clinical_similarity' in hit.metadata:
                print(f" Clinical Similarity: {hit.metadata['clinical_similarity']:.3f}")
            print(f" Content: {hit.content[:100]}...")

    print("\nβ Compatible Vector Store Test Completed")
if __name__ == "__main__":
    # Executing this file directly runs the manual smoke test.
    test_compatible_vector_store()