Update app.py
app.py CHANGED
@@ -177,10 +177,35 @@ class SentenceTransformerRetriever:
         return None
 
     @log_function
-    def encode(self, texts: List[str], batch_size: int =
+    def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:  # Increased batch size
         try:
-
+            # Show a Streamlit progress bar
+            progress_text = "Processing documents..."
+            progress_bar = st.progress(0)
+
+            total_batches = len(texts) // batch_size + (1 if len(texts) % batch_size != 0 else 0)
+            all_embeddings = []
+
+            for i in range(0, len(texts), batch_size):
+                batch = texts[i:i + batch_size]
+                batch_embeddings = self.model.encode(
+                    batch,
+                    convert_to_tensor=True,
+                    show_progress_bar=False  # Disable tqdm progress bar
+                )
+                all_embeddings.append(batch_embeddings)
+
+                # Update progress
+                progress = min((i + batch_size) / len(texts), 1.0)
+                progress_bar.progress(progress)
+
+            # Clear progress bar
+            progress_bar.empty()
+
+            # Concatenate all embeddings
+            embeddings = torch.cat(all_embeddings, dim=0)
             return F.normalize(embeddings, p=2, dim=1)
+
         except Exception as e:
             logging.error(f"Error encoding texts: {str(e)}")
             raise
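The new `encode` processes texts in manual batches mainly so the Streamlit progress bar can be updated between batches; sentence-transformers' own `encode` already batches internally via its `batch_size` argument (note also that `progress_text` is assigned but never used). A minimal standalone sketch of the same pattern with the UI calls stripped out, where `model` and `texts` are assumed names rather than code from the diff:

```python
# Minimal sketch of the batched-encoding pattern above, without the
# Streamlit UI. `model` is assumed to be a loaded SentenceTransformer
# and `texts` a list of strings; neither name comes from the diff.
from typing import List

import torch
import torch.nn.functional as F

def encode_batched(model, texts: List[str], batch_size: int = 64) -> torch.Tensor:
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        all_embeddings.append(
            model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
        )
    embeddings = torch.cat(all_embeddings, dim=0)
    # L2-normalize so dot products equal cosine similarity
    return F.normalize(embeddings, p=2, dim=1)
```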
@@ -274,38 +299,59 @@ class RAGPipeline:
     @st.cache_data
     def load_and_process_csvs(_self):
         try:
+            # Try loading from cache first
            cache_data = _self.retriever.load_cache(_self.data_folder)
            if cache_data is not None:
                _self.documents = cache_data['documents']
                _self.retriever.store_embeddings(cache_data['embeddings'])
+                st.success("Loaded documents from cache")
                return
 
+            st.info("Processing documents... This may take a while.")
            csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
            if not csv_files:
                raise FileNotFoundError(f"No CSV files found in {_self.data_folder}")
 
            all_documents = []
-            for csv_file in csv_files:
+            total_files = len(csv_files)
+
+            # Create a progress bar
+            progress_bar = st.progress(0)
+
+            for idx, csv_file in enumerate(csv_files):
                try:
-                    df = pd.read_csv(csv_file)
+                    df = pd.read_csv(csv_file, low_memory=False)  # Added low_memory=False
                    texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
                    all_documents.extend(texts)
+
+                    # Update progress
+                    progress = (idx + 1) / total_files
+                    progress_bar.progress(progress)
+
                except Exception as e:
                    logging.error(f"Error processing file {csv_file}: {e}")
                    continue
 
+            # Clear progress bar
+            progress_bar.empty()
+
            if not all_documents:
                raise ValueError("No documents were successfully loaded")
 
+            st.info(f"Processing {len(all_documents)} documents...")
            _self.documents = all_documents
            embeddings = _self.retriever.encode(all_documents)
            _self.retriever.store_embeddings(embeddings)
 
+            # Save to cache
            cache_data = {
                'embeddings': embeddings,
                'documents': _self.documents
            }
            _self.retriever.save_cache(_self.data_folder, cache_data)
+
+            st.success("Document processing complete!")
+
        except Exception as e:
            logging.error(f"Error in load_and_process_csvs: {str(e)}")
            raise
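Two details in this hunk are worth noting. First, the `_self` parameter name keeps `@st.cache_data` from trying to hash the instance (Streamlit skips arguments whose names begin with an underscore). Second, every CSV row is flattened into a single retrieval document by joining all of its columns as strings. A toy illustration of that row-flattening, with made-up data:

```python
# Toy illustration of the row-flattening used above: each CSV row
# becomes one whitespace-joined text document. The data is made up.
import pandas as pd

df = pd.DataFrame({"player": ["LeBron James"], "points": [30], "team": ["LAL"]})
texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
print(texts)  # ['LeBron James 30 LAL']
```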
@@ -403,13 +449,20 @@ def initialize_rag_pipeline():
        data_folder = "ESPN_data"
        if not os.path.exists(data_folder):
            os.makedirs(data_folder, exist_ok=True)
+
+        # Check for cache
+        cache_path = os.path.join("embeddings_cache", "embeddings.pkl")
+        if os.path.exists(cache_path):
+            st.info("Found cached data. Loading...")
+        else:
+            st.warning("Initial setup may take several minutes...")
 
        rag = RAGPipeline(data_folder)
-        rag.load_and_process_csvs()
        return rag
+
    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
-        st.error("Failed to initialize the system. Please check
+        st.error("Failed to initialize the system. Please check if all required files are present.")
        raise
 
 def main():
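This hunk also drops the `rag.load_and_process_csvs()` call from `initialize_rag_pipeline()`, so document loading must now be triggered by the caller. A hedged sketch of what that might look like in `main()`; this is an assumption about the surrounding code, not part of the diff:

```python
# Hypothetical sketch, not from the diff: with the call removed from
# initialize_rag_pipeline(), main() would need to load documents itself.
def main():
    rag = initialize_rag_pipeline()
    rag.load_and_process_csvs()  # re-runs are cheap thanks to @st.cache_data
    # ... rest of the Streamlit UI ...
```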