VedaMD-Backend-v2 / src /vector_store_compatibility.py
sniro23's picture
VedaMD Enhanced: Clean deployment with 5x Enhanced Medical RAG System
01f0120
#!/usr/bin/env python3
"""
Vector Store Compatibility Wrapper
VedaMD Medical RAG - Compatibility Fix
This module provides a compatibility wrapper to handle dimension mismatches
between Clinical ModernBERT embeddings (768d) and existing vector store (384d).
TEMPORARY SOLUTION:
- Allows testing of enhanced medical RAG pipeline
- Handles dimension conversion for compatibility
- Maintains medical domain benefits where possible
FUTURE: Rebuild vector store with Clinical ModernBERT for full 768d benefits
"""
import numpy as np
import logging
from typing import List
from sentence_transformers import SentenceTransformer
from simple_vector_store import SimpleVectorStore, SearchResult
class CompatibleMedicalVectorStore:
    """
    Compatibility wrapper for vector store dimension mismatches.

    Retrieval is delegated to a ``SimpleVectorStore`` built with the original
    ``all-MiniLM-L6-v2`` embedder (384d), so the existing index is queried
    unchanged. The retrieved candidates are then re-scored with a weighted
    blend of the store's own score and the cosine similarity between Clinical
    ModernBERT (768d) embeddings of the query and of each candidate's content.
    """

    # Weights of the re-ranking blend:
    #   enhanced = ORIGINAL_SCORE_WEIGHT * store_score
    #            + CLINICAL_SIMILARITY_WEIGHT * clinical_cosine
    # Kept as class attributes so a subclass or instance can tune them without
    # changing any signature.
    # NOTE(review): blending assumes the store's score is a similarity on a
    # scale comparable to cosine (higher is better) — confirm against
    # SimpleVectorStore.search.
    ORIGINAL_SCORE_WEIGHT = 0.6
    CLINICAL_SIMILARITY_WEIGHT = 0.4

    def __init__(self, repo_id: str = "sniro23/VedaMD-Vector-Store"):
        """
        Load the original vector store and the Clinical ModernBERT embedder.

        Args:
            repo_id: Repository identifier passed through to ``SimpleVectorStore``.
        """
        self.setup_logging()
        # Initialize both embedding models for compatibility
        self.logger.info("πŸ”§ Initializing Vector Store Compatibility Layer...")
        # Original vector store with existing embeddings (384d)
        self.original_vector_store = SimpleVectorStore(
            repo_id=repo_id,
            embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"  # Match original
        )
        self.logger.info("βœ… Original vector store loaded (384d)")
        # Clinical ModernBERT for enhanced medical understanding (768d)
        self.clinical_embedder = SentenceTransformer("Simonlee711/Clinical_ModernBERT")
        self.logger.info("βœ… Clinical ModernBERT loaded (768d)")
        self.logger.info("🎯 Vector Store Compatibility Layer ready")

    def setup_logging(self):
        """
        Configure logging and bind ``self.logger`` to this module's logger.

        Calls ``logging.basicConfig`` at INFO level. That call configures the
        root logger and has no effect if the root logger already has handlers.
        """
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def search(self, query: str, k: int = 5) -> List[SearchResult]:
        """
        Retrieve up to ``k`` documents for ``query`` and re-rank them.

        Retrieval uses the original (384d) vector store. Any results are then
        re-scored with Clinical ModernBERT similarity; see
        ``_enhance_with_clinical_similarity``.

        Args:
            query: Free-text search query.
            k: Number of documents requested from the underlying store.

        Returns:
            The re-ranked results, best first. If the store returned nothing,
            or if enhancement failed, the store's results are returned as-is.
        """
        self.logger.info(f"πŸ” Searching with compatibility layer: {query[:50]}...")
        # Use original vector store for retrieval (384d compatibility)
        results = self.original_vector_store.search(query=query, k=k)
        if not results:
            return results
        enhanced_results = self._enhance_with_clinical_similarity(query, results)
        self.logger.info(f"βœ… Retrieved {len(enhanced_results)} documents with medical enhancement")
        return enhanced_results

    @staticmethod
    def _cosine_similarities(query_vec: np.ndarray, doc_matrix: np.ndarray) -> np.ndarray:
        """
        Return the cosine similarity of ``query_vec`` with each row of ``doc_matrix``.

        Rows (or a query) with zero norm get a similarity of 0.0 instead of NaN.
        A NaN would make the descending sort of results ill-defined.
        """
        norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec)
        dots = doc_matrix @ query_vec
        nonzero = norms > 0
        # Divide by 1.0 where the norm is zero to avoid a division warning;
        # those entries are replaced by 0.0 anyway.
        return np.where(nonzero, dots / np.where(nonzero, norms, 1.0), 0.0)

    def _enhance_with_clinical_similarity(self, query: str, results: List[SearchResult]) -> List[SearchResult]:
        """
        Re-score ``results`` with Clinical ModernBERT cosine similarity.

        Each result's new score is
        ``ORIGINAL_SCORE_WEIGHT * result.score + CLINICAL_SIMILARITY_WEIGHT * cosine``.
        The original score, the clinical similarity and the blended score are
        also recorded in the new result's metadata.

        This step is best-effort. If encoding or scoring raises, the error is
        logged as a warning and ``results`` is returned unchanged.

        Args:
            query: The search query.
            results: Candidates returned by the original vector store.

        Returns:
            A new list of ``SearchResult`` sorted by blended score, highest
            first. Returns ``results`` itself when it is empty or on failure.
        """
        if not results:
            return results
        try:
            # Encode the query once and all candidates in a single batched call,
            # rather than one model invocation per document.
            query_embedding = np.asarray(self.clinical_embedder.encode([query]), dtype=np.float64)[0]
            doc_embeddings = np.asarray(
                self.clinical_embedder.encode([result.content for result in results]),
                dtype=np.float64,
            )
            similarities = self._cosine_similarities(query_embedding, doc_embeddings)

            enhanced_results = []
            for result, similarity in zip(results, similarities):
                clinical_similarity = float(similarity)
                # Plain Python float so ``score`` matches the metadata values and
                # does not leak a NumPy scalar to callers.
                enhanced_score = float(
                    self.ORIGINAL_SCORE_WEIGHT * result.score
                    + self.CLINICAL_SIMILARITY_WEIGHT * clinical_similarity
                )
                enhanced_results.append(SearchResult(
                    content=result.content,
                    score=enhanced_score,
                    metadata={
                        # Tolerate results whose metadata is None.
                        **(result.metadata or {}),
                        'original_score': result.score,
                        'clinical_similarity': clinical_similarity,
                        'enhanced_score': enhanced_score,
                    },
                ))
        except Exception as e:
            # Deliberate best-effort fallback: enhancement must never break retrieval.
            self.logger.warning(f"Clinical enhancement failed: {e}. Using original results.")
            return results
        # Sort by enhanced score
        enhanced_results.sort(key=lambda r: r.score, reverse=True)
        return enhanced_results
def test_compatible_vector_store():
    """
    Manually exercise CompatibleMedicalVectorStore on a few sample queries.

    Builds a store with the default repository and runs three medical
    queries with k=3. For each hit it prints the score, the clinical
    similarity when the hit's metadata carries one, and the first 100
    characters of its content. This is a console demo; it asserts nothing.
    """
    print("πŸ§ͺ Testing Compatible Vector Store")
    store = CompatibleMedicalVectorStore()

    sample_queries = (
        "preeclampsia management protocol",
        "postpartum hemorrhage treatment",
        "contraindicated medications pregnancy",
    )

    for sample in sample_queries:
        print(f"\nπŸ” Query: {sample}")
        hits = store.search(sample, k=3)
        for rank, hit in enumerate(hits, start=1):
            print(f" {rank}. Score: {hit.score:.3f}")
            meta = hit.metadata
            if 'clinical_similarity' in meta:
                print(f" Clinical Similarity: {meta['clinical_similarity']:.3f}")
            print(f" Content: {hit.content[:100]}...")

    print("\nβœ… Compatible Vector Store Test Completed")
if __name__ == "__main__":
    # Script entry point: run the demo queries. Constructing the store loads
    # both embedding models, so this is slow and needs the model downloads.
    test_compatible_vector_store()