File size: 5,685 Bytes
01f0120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
"""
Vector Store Compatibility Wrapper
VedaMD Medical RAG - Compatibility Fix

This module provides a compatibility wrapper to handle dimension mismatches
between Clinical ModernBERT embeddings (768d) and existing vector store (384d).

TEMPORARY SOLUTION:
- Allows testing of enhanced medical RAG pipeline
- Sidesteps the 768d/384d mismatch: retrieval uses the original 384d store, then results are re-scored with Clinical ModernBERT
- Maintains medical domain benefits where possible

FUTURE: Rebuild vector store with Clinical ModernBERT for full 768d benefits
"""

import numpy as np
import logging
from typing import List
from sentence_transformers import SentenceTransformer
from simple_vector_store import SimpleVectorStore, SearchResult

class CompatibleMedicalVectorStore:
    """
    Compatibility wrapper for vector store dimension mismatches.

    Retrieval is delegated to the existing ``SimpleVectorStore`` (configured
    with the 384d ``all-MiniLM-L6-v2`` model). The retrieved documents are then
    re-scored with Clinical ModernBERT and re-ranked by a weighted blend of the
    original retrieval score and the clinical cosine similarity. No embedding
    is ever projected from one dimensionality to the other.
    """

    # Blend weights for the final score. Subclasses or instances may override
    # them; the defaults reproduce the original 60/40 weighting.
    ORIGINAL_WEIGHT = 0.6
    CLINICAL_WEIGHT = 0.4

    def __init__(self, repo_id: str = "sniro23/VedaMD-Vector-Store"):
        """
        Load the original vector store and the Clinical ModernBERT embedder.

        Args:
            repo_id: Repository identifier forwarded to ``SimpleVectorStore``.
        """
        self.setup_logging()

        # Initialize both embedding models for compatibility
        self.logger.info("πŸ”§ Initializing Vector Store Compatibility Layer...")

        # Original vector store with existing embeddings (384d)
        self.original_vector_store = SimpleVectorStore(
            repo_id=repo_id,
            embedding_model_name="sentence-transformers/all-MiniLM-L6-v2"  # Match original
        )
        self.logger.info("βœ… Original vector store loaded (384d)")

        # Clinical ModernBERT for enhanced medical understanding (768d)
        self.clinical_embedder = SentenceTransformer("Simonlee711/Clinical_ModernBERT")
        self.logger.info("βœ… Clinical ModernBERT loaded (768d)")

        self.logger.info("🎯 Vector Store Compatibility Layer ready")

    def setup_logging(self):
        """
        Configure root logging and attach a module-named logger to ``self``.

        ``logging.basicConfig`` does nothing once the root logger already has
        handlers, so repeated instantiation does not stack handlers.
        """
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def search(self, query: str, k: int = 5) -> List[SearchResult]:
        """
        Search with compatibility layer - uses original vector store for retrieval.

        Args:
            query: Free-text query.
            k: Number of documents requested from the original vector store.

        Returns:
            The retrieved results re-ranked by the blended score, or the
            original (empty) result unchanged when nothing was retrieved.
        """
        self.logger.info(f"πŸ” Searching with compatibility layer: {query[:50]}...")

        # Use original vector store for retrieval (384d compatibility)
        results = self.original_vector_store.search(query=query, k=k)
        if not results:
            return results

        # Enhance results with Clinical ModernBERT similarity scoring
        enhanced_results = self._enhance_with_clinical_similarity(query, results)
        self.logger.info(f"βœ… Retrieved {len(enhanced_results)} documents with medical enhancement")
        return enhanced_results

    @staticmethod
    def _cosine_similarity(vec_a, vec_b) -> float:
        """Return the cosine similarity of two vectors, or 0.0 if it is undefined."""
        a = np.asarray(vec_a, dtype=np.float64)
        b = np.asarray(vec_b, dtype=np.float64)
        denominator = np.linalg.norm(a) * np.linalg.norm(b)
        # A zero (or non-finite) norm would yield NaN/inf, which poisons the
        # blended score and makes the final sort order meaningless.
        if denominator == 0.0 or not np.isfinite(denominator):
            return 0.0
        return float(np.dot(a, b) / denominator)

    def _enhance_with_clinical_similarity(self, query: str, results: List[SearchResult]) -> List[SearchResult]:
        """
        Enhance search results with Clinical ModernBERT similarity scoring.

        Each result gets a new score
        ``ORIGINAL_WEIGHT * result.score + CLINICAL_WEIGHT * cosine(query, doc)``
        and its metadata is extended with ``original_score``,
        ``clinical_similarity`` and ``enhanced_score``.

        Args:
            query: The user query to embed with the clinical model.
            results: Results returned by the original vector store.

        Returns:
            New ``SearchResult`` objects sorted by blended score (descending).
            On any failure the original ``results`` are returned unchanged
            (best-effort enhancement) and a warning is logged.
        """
        if not results:
            return []

        try:
            # Get clinical embedding for query
            query_embedding = self.clinical_embedder.encode([query])[0]

            # Embed every document in ONE batched call instead of invoking the
            # model once per result inside the loop.
            doc_embeddings = self.clinical_embedder.encode(
                [result.content for result in results]
            )

            enhanced_results = []
            for result, doc_embedding in zip(results, doc_embeddings):
                clinical_similarity = self._cosine_similarity(query_embedding, doc_embedding)

                # Combine original score with clinical similarity (weighted
                # average). Cast to a plain float so ``score`` has the same
                # type as the values recorded in metadata.
                enhanced_score = float(
                    self.ORIGINAL_WEIGHT * result.score
                    + self.CLINICAL_WEIGHT * clinical_similarity
                )

                enhanced_results.append(SearchResult(
                    content=result.content,
                    score=enhanced_score,
                    metadata={
                        # Tolerate results whose metadata is None.
                        **(result.metadata or {}),
                        'original_score': result.score,
                        'clinical_similarity': clinical_similarity,
                        'enhanced_score': enhanced_score
                    }
                ))

            # Sort by enhanced score
            enhanced_results.sort(key=lambda item: item.score, reverse=True)
            return enhanced_results

        except Exception as e:
            # Deliberate best-effort: re-ranking must never break retrieval.
            self.logger.warning(f"Clinical enhancement failed: {e}. Using original results.")
            return results

def test_compatible_vector_store():
    """
    Smoke-test the compatibility wrapper.

    Builds a ``CompatibleMedicalVectorStore``, runs three sample medical
    queries with ``k=3`` and prints each hit's score, its clinical similarity
    (when present in the metadata) and the first 100 characters of content.
    """
    print("πŸ§ͺ Testing Compatible Vector Store")

    vector_store = CompatibleMedicalVectorStore()

    # Sample medical queries
    sample_queries = (
        "preeclampsia management protocol",
        "postpartum hemorrhage treatment",
        "contraindicated medications pregnancy",
    )

    for sample_query in sample_queries:
        print(f"\nπŸ” Query: {sample_query}")
        hits = vector_store.search(sample_query, k=3)

        rank = 0
        for hit in hits:
            rank += 1
            print(f"   {rank}. Score: {hit.score:.3f}")

            hit_metadata = hit.metadata
            if 'clinical_similarity' in hit_metadata:
                similarity = hit_metadata['clinical_similarity']
                print(f"      Clinical Similarity: {similarity:.3f}")

            preview = hit.content[:100]
            print(f"      Content: {preview}...")

    print("\nβœ… Compatible Vector Store Test Completed")

# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_compatible_vector_store()