Spaces:
Runtime error
Update app.py
app.py
CHANGED

@@ -123,17 +123,21 @@ def check_environment():
 
 class SentenceTransformerRetriever:
     def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
-        self.device = torch.device("cpu")
-        self.model_name = model_name
-        self.cache_dir = cache_dir
-        self.cache_file = "embeddings.pkl"
-        self.doc_embeddings = None
-        os.makedirs(cache_dir, exist_ok=True)
-        # Initialize model using cached method
-        self.model = self._load_model(model_name)
+        try:
+            self.device = torch.device("cpu")
+            self.model_name = model_name
+            self.cache_dir = cache_dir
+            self.cache_file = "embeddings.pkl"
+            self.doc_embeddings = None
+            os.makedirs(cache_dir, exist_ok=True)
+            # Initialize model using cached method
+            self.model = self._load_model(model_name)
+        except Exception as e:
+            logging.error(f"Error initializing SentenceTransformerRetriever: {str(e)}")
+            raise
 
     @st.cache_resource(show_spinner=False)
-    def _load_model(_self, _model_name: str):
+    def _load_model(_self, _model_name: str):
         """Load and cache the sentence transformer model"""
         try:
             with warnings.catch_warnings():
@@ -144,11 +148,17 @@ class SentenceTransformerRetriever:
             if not isinstance(test_embedding, torch.Tensor):
                 raise ValueError("Model initialization failed")
             return model
+        except Exception as e:
+            logging.error(f"Error loading model: {str(e)}")
+            raise
+
     def get_cache_path(self, data_folder: str = None) -> str:
+        """Get the path for cache file"""
         return os.path.join(self.cache_dir, self.cache_file)
 
     @log_function
     def save_cache(self, data_folder: str, cache_data: dict):
+        """Save embeddings to cache"""
         try:
             cache_path = self.get_cache_path()
             if os.path.exists(cache_path):
@@ -162,7 +172,8 @@ class SentenceTransformerRetriever:
 
     @log_function
     @st.cache_data
-    def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
+    def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
+        """Load embeddings from cache"""
         try:
             cache_path = _self.get_cache_path()
             if os.path.exists(cache_path):
@@ -179,6 +190,7 @@ class SentenceTransformerRetriever:
 
     @log_function
     def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
+        """Encode texts into embeddings"""
         try:
             embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
             return F.normalize(embeddings, p=2, dim=1)
@@ -188,23 +200,29 @@ class SentenceTransformerRetriever:
 
     @log_function
     def store_embeddings(self, embeddings: torch.Tensor):
+        """Store embeddings in memory"""
         self.doc_embeddings = embeddings
 
     @log_function
     def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
-        if self.doc_embeddings is None:
-            raise ValueError("No document embeddings stored!")
-
-        similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
-        k = min(k, len(documents))
-        scores, indices = torch.topk(similarities, k=k)
-
-        logging.info(f"\nSimilarity Stats:")
-        logging.info(f"Max similarity: {similarities.max().item():.4f}")
-        logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
-        logging.info(f"Selected similarities: {scores.tolist()}")
-
-        return indices.cpu(), scores.cpu()
+        """Search for similar documents"""
+        try:
+            if self.doc_embeddings is None:
+                raise ValueError("No document embeddings stored!")
+
+            similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
+            k = min(k, len(documents))
+            scores, indices = torch.topk(similarities, k=k)
+
+            logging.info(f"\nSimilarity Stats:")
+            logging.info(f"Max similarity: {similarities.max().item():.4f}")
+            logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
+            logging.info(f"Selected similarities: {scores.tolist()}")
+
+            return indices.cpu(), scores.cpu()
+        except Exception as e:
+            logging.error(f"Error in search: {str(e)}")
+            raise
 
 class RAGPipeline:
     def __init__(self, data_folder: str, k: int = 5):
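
For reference, a minimal sketch of the retrieval math that encode(), store_embeddings(), and search() implement, runnable outside Streamlit (the docs list and query string are made-up examples; imports mirror what app.py already uses):

from typing import List

import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

docs: List[str] = [
    "Streamlit caches resources across reruns.",
    "Cosine similarity compares normalized embeddings.",
    "RAG pipelines retrieve documents before generating.",
]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# encode() normalizes with F.normalize(..., p=2, dim=1), so cosine
# similarity on the results reduces to a dot product of unit vectors.
doc_embeddings = F.normalize(model.encode(docs, convert_to_tensor=True), p=2, dim=1)
query_embedding = F.normalize(
    model.encode(["How does retrieval work?"], convert_to_tensor=True), p=2, dim=1
)

# Same core steps as search(): cosine similarity, clamp k, top-k.
similarities = F.cosine_similarity(query_embedding, doc_embeddings)
scores, indices = torch.topk(similarities, k=min(2, len(docs)))
for score, idx in zip(scores.tolist(), indices.tolist()):
    print(f"{score:.4f}  {docs[idx]}")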
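A note on the underscore-prefixed parameters (_self, _model_name, _data_folder) this commit keeps: Streamlit's @st.cache_resource and @st.cache_data hash a function's arguments to build the cache key, and a leading underscore tells Streamlit to skip hashing that argument, which is what lets these decorators sit on instance methods. A stripped-down sketch of the pattern (the Retriever class name is illustrative):

import streamlit as st
from sentence_transformers import SentenceTransformer

class Retriever:
    @st.cache_resource(show_spinner=False)
    def _load_model(_self, _model_name: str):
        # Underscore-prefixed arguments are excluded from Streamlit's cache
        # key: _self because instances are unhashable, and _model_name too,
        # meaning a changed model name will NOT invalidate this cache entry.
        return SentenceTransformer(_model_name)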