Spaces:
Sleeping
Sleeping
File size: 5,685 Bytes
01f0120 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
#!/usr/bin/env python3
"""
Vector Store Compatibility Wrapper
VedaMD Medical RAG - Compatibility Fix
This module provides a compatibility wrapper to handle dimension mismatches
between Clinical ModernBERT embeddings (768d) and existing vector store (384d).
TEMPORARY SOLUTION:
- Allows testing of enhanced medical RAG pipeline
- Handles dimension conversion for compatibility
- Maintains medical domain benefits where possible
FUTURE: Rebuild vector store with Clinical ModernBERT for full 768d benefits
"""
import numpy as np
import logging
from typing import List
from sentence_transformers import SentenceTransformer
from simple_vector_store import SimpleVectorStore, SearchResult
class CompatibleMedicalVectorStore:
    """Compatibility wrapper for vector store dimension mismatches.

    Retrieval is delegated to a ``SimpleVectorStore`` built with the
    ``all-MiniLM-L6-v2`` embedding model (the model the stored vectors are
    expected to match). Retrieved documents are then re-scored with
    Clinical ModernBERT embeddings and re-ranked by a weighted blend of the
    retrieval score and the clinical cosine similarity.
    """

    # Weights of the blended score: retrieval score vs. clinical similarity.
    ORIGINAL_SCORE_WEIGHT = 0.6
    CLINICAL_SCORE_WEIGHT = 0.4

    def __init__(self, repo_id: str = "sniro23/VedaMD-Vector-Store"):
        """Load the original vector store and the clinical embedding model.

        Args:
            repo_id: Repository identifier passed to ``SimpleVectorStore``.
        """
        self.setup_logging()

        # Initialize both embedding models for compatibility
        self.logger.info("🔧 Initializing Vector Store Compatibility Layer...")

        # Original vector store with existing embeddings (384d)
        self.original_vector_store = SimpleVectorStore(
            repo_id=repo_id,
            embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # Match original
        )
        self.logger.info("✅ Original vector store loaded (384d)")

        # Clinical ModernBERT for enhanced medical understanding (768d)
        self.clinical_embedder = SentenceTransformer("Simonlee711/Clinical_ModernBERT")
        self.logger.info("✅ Clinical ModernBERT loaded (768d)")

        self.logger.info("🎯 Vector Store Compatibility Layer ready")

    def setup_logging(self):
        """Configure root logging at INFO level and create ``self.logger``."""
        # basicConfig does nothing when the root logger already has handlers,
        # so constructing several instances does not stack handlers.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        )
        self.logger = logging.getLogger(__name__)

    def search(self, query: str, k: int = 5) -> List["SearchResult"]:
        """Search through the original vector store, then re-rank the hits.

        Args:
            query: Free-text query; only its first 50 characters are logged.
            k: Number of results requested from the original vector store.

        Returns:
            The re-ranked results, or the store's own (empty) result when
            nothing was retrieved.
        """
        self.logger.info("🔍 Searching with compatibility layer: %s...", query[:50])

        # Use original vector store for retrieval (384d compatibility)
        results = self.original_vector_store.search(query=query, k=k)
        if not results:
            return results

        # Enhance results with Clinical ModernBERT similarity scoring
        enhanced_results = self._enhance_with_clinical_similarity(query, results)
        self.logger.info(
            "✅ Retrieved %d documents with medical enhancement",
            len(enhanced_results),
        )
        return enhanced_results

    @staticmethod
    def _cosine_similarity(a, b) -> float:
        """Return the cosine similarity of two vectors as a plain ``float``.

        Returns 0.0 when either vector has zero norm, instead of dividing by
        zero and yielding NaN (which would make the later sort ill-defined).
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        if denominator == 0.0:
            return 0.0
        return float(np.dot(a, b) / denominator)

    def _enhance_with_clinical_similarity(
        self, query: str, results: List["SearchResult"]
    ) -> List["SearchResult"]:
        """Re-score and re-rank results with Clinical ModernBERT similarity.

        Each result's new score is
        ``ORIGINAL_SCORE_WEIGHT * score + CLINICAL_SCORE_WEIGHT * cosine``,
        where ``cosine`` is the similarity between the clinical embeddings of
        the query and of the result content. The original score, the clinical
        similarity and the blended score are added to the result metadata.

        Args:
            query: The search query.
            results: Results returned by the original vector store.

        Returns:
            New ``SearchResult`` objects sorted by blended score, descending.
            On any failure the input ``results`` are returned unchanged.
        """
        if not results:
            return results

        try:
            # Encode the query once and all documents in one batched call
            # rather than one model invocation per document.
            query_embedding = self.clinical_embedder.encode([query])[0]
            doc_embeddings = self.clinical_embedder.encode(
                [result.content for result in results]
            )

            enhanced_results = []
            for result, doc_embedding in zip(results, doc_embeddings):
                clinical_similarity = self._cosine_similarity(
                    query_embedding, doc_embedding
                )

                # Combine original score with clinical similarity (weighted average)
                enhanced_score = float(
                    self.ORIGINAL_SCORE_WEIGHT * result.score
                    + self.CLINICAL_SCORE_WEIGHT * clinical_similarity
                )

                enhanced_results.append(
                    SearchResult(
                        content=result.content,
                        score=enhanced_score,
                        metadata={
                            # Tolerate results that carry no metadata.
                            **(result.metadata or {}),
                            'original_score': result.score,
                            'clinical_similarity': clinical_similarity,
                            'enhanced_score': enhanced_score,
                        },
                    )
                )

            # Sort by enhanced score
            enhanced_results.sort(key=lambda item: item.score, reverse=True)
            return enhanced_results

        except Exception as e:
            # Deliberate best-effort: re-ranking is optional, so any failure
            # falls back to the un-enhanced retrieval results.
            self.logger.warning(
                "Clinical enhancement failed: %s. Using original results.", e
            )
            return results
def test_compatible_vector_store():
    """Smoke-test the compatible vector store with sample medical queries.

    Builds a ``CompatibleMedicalVectorStore`` with its default arguments,
    runs three queries with ``k=3`` and prints, for every hit, its score,
    its clinical similarity (when present in the metadata) and the first
    100 characters of its content.
    """
    print("🧪 Testing Compatible Vector Store")
    store = CompatibleMedicalVectorStore()

    # Test medical queries
    test_queries = [
        "preeclampsia management protocol",
        "postpartum hemorrhage treatment",
        "contraindicated medications pregnancy",
    ]

    for query in test_queries:
        print(f"\n🔍 Query: {query}")
        results = store.search(query, k=3)
        for i, result in enumerate(results, 1):
            print(f" {i}. Score: {result.score:.3f}")
            # The key is only present when clinical re-ranking succeeded.
            if 'clinical_similarity' in result.metadata:
                print(f" Clinical Similarity: {result.metadata['clinical_similarity']:.3f}")
            print(f" Content: {result.content[:100]}...")

    print("\n✅ Compatible Vector Store Test Completed")
# Run the smoke test when the module is executed as a script.
if __name__ == "__main__":
    test_compatible_vector_store()