"""
Simple Vector Store for Medical RAG - Runtime Version
This version is designed to load a pre-computed vector store from the Hugging Face Hub.
"""
import os
import json
import logging
import time
from typing import List, Dict, Any, Optional
from pathlib import Path
import numpy as np
from dataclasses import dataclass
import faiss
from sentence_transformers import SentenceTransformer
from langchain_core.documents import Document
from huggingface_hub import hf_hub_download


@dataclass
class SearchResult:
    """Simple search result structure"""
    content: str
    score: float
    metadata: Dict[str, Any]


class SimpleVectorStore:
    """
    A simplified vector store that loads its index and documents from the Hugging Face Hub.
    It does not contain any logic for creating embeddings or building an index at runtime.
    """
    def __init__(self,
                 repo_id: Optional[str] = None,
                 local_dir: Optional[str] = None,
                 embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initializes the vector store by loading from the HF Hub or a local directory.

        Args:
            repo_id (str): The Hugging Face Hub repository ID (e.g., "user/repo-name"). Optional if local_dir is provided.
            local_dir (str): Local directory containing the vector store files. Optional if repo_id is provided.
            embedding_model_name (str): The embedding model to use for query embedding.
                Defaults to sentence-transformers/all-MiniLM-L6-v2 (384d).
        """
        if not repo_id and not local_dir:
            raise ValueError("Either repo_id or local_dir must be provided")

        self.repo_id = repo_id
        self.local_dir = local_dir
        self.embedding_model_name = embedding_model_name
        self.setup_logging()

        # Log the embedding model choice for the medical domain
        if "Clinical" in embedding_model_name or "Bio" in embedding_model_name:
            self.logger.info(f"Using medical domain embedding model: {embedding_model_name}")
        else:
            self.logger.warning(f"Using general domain embedding model: {embedding_model_name}")

        self.embedding_model = None
        self.index = None
        self.documents = []
        self.metadata = []
        self._initialize_embedding_model()

        # Load from the local directory or the HF Hub
        if self.local_dir:
            self.load_from_local_directory()
        else:
            self.load_from_huggingface_hub()

    def setup_logging(self):
        """Setup logging for the vector store"""
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)

    def _initialize_embedding_model(self):
        """Initialize the sentence transformer model for creating query embeddings."""
        try:
            self.logger.info(f"Loading embedding model: {self.embedding_model_name}")
            self.embedding_model = SentenceTransformer(self.embedding_model_name)
            self.logger.info("Embedding model loaded successfully.")
        except Exception as e:
            self.logger.error(f"Error loading embedding model: {e}")
            raise

    def load_from_local_directory(self):
        """
        Loads the vector store artifacts from a local directory.
        """
        self.logger.info(f"Loading vector store from local directory: {self.local_dir}")
        try:
            local_path = Path(self.local_dir)

            # Check if the directory exists
            if not local_path.exists():
                raise FileNotFoundError(f"Local directory not found: {self.local_dir}")

            # Load the FAISS index
            index_path = local_path / "faiss_index.bin"
            self.index = faiss.read_index(str(index_path))
            self.logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors from local directory.")

            # Load documents and metadata
            docs_path = local_path / "documents.json"
            metadata_path = local_path / "metadata.json"
            config_path = local_path / "config.json"
            with open(docs_path, 'r', encoding='utf-8') as f:
                page_contents = json.load(f)
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadatas = json.load(f)

            # Combine them to reconstruct the documents
            if len(page_contents) != len(metadatas):
                raise ValueError("Mismatch between number of documents and metadata entries.")
            for i in range(len(page_contents)):
                content = page_contents[i] if isinstance(page_contents[i], str) else page_contents[i].get('page_content', '')
                metadata = metadatas[i] if isinstance(metadatas[i], dict) else {}

                # Ensure a valid citation exists
                if not metadata.get('citation'):
                    source_path = metadata.get('source', 'Unknown')
                    if source_path != 'Unknown':
                        metadata['citation'] = Path(source_path).stem.replace('-', ' ').title()
                    else:
                        metadata['citation'] = 'Unknown Source'

                self.documents.append(Document(page_content=content, metadata=metadata))
                self.metadata.append(metadata)
            self.logger.info(f"Loaded {len(self.documents)} documents from local directory.")

            # Load and log the configuration
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.logger.info(f"Vector store configuration loaded: {config}")
        except Exception as e:
            self.logger.error(f"Failed to load vector store from local directory: {e}")
            raise

    def load_from_huggingface_hub(self):
        """
        Downloads the vector store artifacts from the specified Hugging Face Hub repository and loads them.
        """
        self.logger.info(f"Downloading vector store from Hugging Face Hub repo: {self.repo_id}")
        try:
            # Download the four essential files
            index_path = hf_hub_download(repo_id=self.repo_id, filename="faiss_index.bin")
            docs_path = hf_hub_download(repo_id=self.repo_id, filename="documents.json")
            metadata_path = hf_hub_download(repo_id=self.repo_id, filename="metadata.json")
            config_path = hf_hub_download(repo_id=self.repo_id, filename="config.json")
            self.logger.info("Vector store files downloaded successfully.")

            # Load the FAISS index
            self.index = faiss.read_index(index_path)
            self.logger.info(f"Loaded FAISS index with {self.index.ntotal} vectors.")

            # Load the documents and metadata separately
            with open(docs_path, 'r', encoding='utf-8') as f:
                page_contents = json.load(f)
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadatas = json.load(f)

            # Combine them to reconstruct the documents
            if len(page_contents) != len(metadatas):
                raise ValueError("Mismatch between number of documents and metadata entries.")
            for i in range(len(page_contents)):
                content = page_contents[i] if isinstance(page_contents[i], str) else page_contents[i].get('page_content', '')
                metadata = metadatas[i] if isinstance(metadatas[i], dict) else {}

                # Ensure a valid citation exists.
                # If 'citation' is missing or empty, create one from the source file path.
                if not metadata.get('citation'):
                    source_path = metadata.get('source', 'Unknown')
                    if source_path != 'Unknown':
                        # Extract the guideline name from the parent directory of the source file
                        metadata['citation'] = Path(source_path).parent.name.replace('-', ' ').title()
                    else:
                        metadata['citation'] = 'Unknown Source'

                self.documents.append(Document(page_content=content, metadata=metadata))
                self.metadata.append(metadata)
            self.logger.info(f"Loaded {len(self.documents)} documents with improved citations.")

            # Load and log the configuration
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.logger.info(f"Vector store configuration loaded: {config}")
        except Exception as e:
            self.logger.error(f"Failed to load vector store from Hugging Face Hub: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[SearchResult]:
        """
        Searches the vector store for the top-k most similar documents to the query.

        Args:
            query (str): The search query.
            k (int): The number of results to return.

        Returns:
            A list of SearchResult objects.
        """
        if not self.index or not self.documents:
            self.logger.error("Search attempted but vector store is not initialized.")
            return []

        # Create an embedding for the query
        query_embedding = self.embedding_model.encode([query], normalize_embeddings=True)

        # Search the FAISS index
        scores, indices = self.index.search(query_embedding.astype('float32'), k)

        # Process and return the results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if idx == -1:
                continue  # Skip invalid indices
            doc = self.documents[idx]
            results.append(SearchResult(
                content=doc.page_content,
                score=float(score),
                metadata=doc.metadata
            ))
        return results
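

# Note on score interpretation (an assumption about how the index was built
# offline, which this file does not control): if the pre-computed index is an
# inner-product index (faiss.IndexFlatIP) over normalized embeddings, the score
# returned by search() is cosine similarity and higher means more relevant; if
# it was built with faiss.IndexFlatL2, the score is a squared L2 distance and
# lower is better. Either way, FAISS returns results best-first by its own metric.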


def main():
    """Main function to test the simple vector store"""
    print("Testing Simple Vector Store v2.0")
    print("=" * 60)
    try:
        # Initialize the vector store
        vector_store = SimpleVectorStore(
            repo_id="user/repo-name"
        )

        # Test search functionality
        print("\nTESTING SEARCH FUNCTIONALITY:")
        test_queries = [
            "magnesium sulfate dosage preeclampsia",
            "postpartum hemorrhage management",
            "fetal heart rate monitoring",
            "emergency cesarean delivery"
        ]
        for query in test_queries:
            print(f"\nQuery: '{query}'")
            results = vector_store.search(query, k=3)
            for i, result in enumerate(results, 1):
                print(f"  Result {i}: Score={result.score:.3f}, Doc={result.metadata.get('document_name', 'Unknown')}")
                print(f"    Type={result.metadata.get('content_type', 'general')}")
                print(f"    Preview: {result.content[:100]}...")

        print("\nSimple Vector Store Testing Complete!")
        print(f"Successfully loaded vector store with {len(vector_store.documents):,} embeddings")
        print("Search functionality working with high relevance scores")
        return vector_store
    except Exception as e:
        print(f"Error in simple vector store: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    main()