import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Callable, Optional
import glob
from tqdm import tqdm
import pickle
import torch.nn.functional as F
import functools
from datetime import datetime
# CPU-only execution (the model is explicitly loaded with device="cpu" in the retriever class below)
torch.device('cpu')
# Logging configuration
LOGGING_CONFIG = {
    'enabled': True,
    'functions': {
        'encode': True,
        'store_embeddings': True,
        'search': True,
        'load_and_process_csvs': True
    }
}
def log_function(func: Callable) -> Callable:
    """Decorator to log function inputs and outputs"""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
            return func(*args, **kwargs)
        # Resolve the owning class name for bound methods; fall back to the module name otherwise
        if args and hasattr(args[0], '__class__'):
            class_name = args[0].__class__.__name__
        else:
            class_name = func.__module__
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        # Drop `self` from the logged arguments when wrapping a method
        log_args = args[1:] if class_name != func.__module__ else args

        def format_arg(arg):
            if isinstance(arg, torch.Tensor):
                return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
            elif isinstance(arg, list):
                return f"List(len={len(arg)})"
            elif isinstance(arg, str) and len(arg) > 100:
                return f"String(len={len(arg)}): {arg[:100]}..."
            return arg

        formatted_args = [format_arg(arg) for arg in log_args]
        formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}
        print(f"\n{'='*80}")
        print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
        print("INPUTS:")
        print(f"  args: {formatted_args}")
        print(f"  kwargs: {formatted_kwargs}")
        result = func(*args, **kwargs)
        formatted_result = format_arg(result)
        print("OUTPUT:")
        print(f"  {formatted_result}")
        print(f"{'='*80}\n")
        return result
    return wrapper
class SentenceTransformerRetriever:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        # Suppress model-loading warnings and pin the model to the CPU
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.device = torch.device("cpu")
            self.model = SentenceTransformer(model_name, device="cpu")
        self.doc_embeddings = None
        self.cache_dir = cache_dir
        self.cache_file = "embeddings.pkl"
        os.makedirs(cache_dir, exist_ok=True)
    def get_cache_path(self) -> str:
        return os.path.join(self.cache_dir, self.cache_file)
    def save_cache(self, cache_data: dict):
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            os.remove(cache_path)
        with open(cache_path, 'wb') as f:
            pickle.dump(cache_data, f)
        print(f"Cache saved at: {cache_path}")
    def load_cache(self) -> Optional[dict]:
        cache_path = self.get_cache_path()
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                print(f"Loading cache from: {cache_path}")
                return pickle.load(f)
        return None
    @log_function
    def encode(self, texts: List[str], batch_size: int = 32) -> torch.Tensor:
        embeddings = self.model.encode(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
        # L2-normalize so that dot products between embeddings equal cosine similarity
        return F.normalize(embeddings, p=2, dim=1)
    @log_function
    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings
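
    # Note: LOGGING_CONFIG above also enables logging for a 'search' function that is
    # not defined in this file. A minimal sketch of such a method is included here,
    # assuming cosine-similarity retrieval over the normalized embeddings stored by
    # store_embeddings; the name, signature, and return format are assumptions, not
    # part of the original code.
    @log_function
    def search(self, query: str, documents: List[str], top_k: int = 5):
        query_embedding = self.encode([query])  # shape (1, dim), already L2-normalized
        # With unit-normalized vectors, the dot product is the cosine similarity
        scores = (query_embedding @ self.doc_embeddings.T).squeeze(0)
        top_scores, top_indices = torch.topk(scores, k=min(top_k, len(documents)))
        return [(documents[i], float(s)) for i, s in zip(top_indices.tolist(), top_scores.tolist())]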
def process_data(data_folder: str):
    retriever = SentenceTransformerRetriever()
    documents = []
    # Check cache first
    cache_data = retriever.load_cache()
    if cache_data is not None:
        print("Using cached embeddings")
        return cache_data
    # Process CSV files
    csv_files = glob.glob(os.path.join(data_folder, "*.csv"))
    for csv_file in tqdm(csv_files, desc="Reading CSV files"):
        try:
            df = pd.read_csv(csv_file)
            # Join every column of each row into a single text document
            texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
            documents.extend(texts)
        except Exception as e:
            print(f"Error processing file {csv_file}: {e}")
            continue
    # Generate embeddings
    embeddings = retriever.encode(documents)
    # Save cache
    cache_data = {
        'embeddings': embeddings,
        'documents': documents
    }
    retriever.save_cache(cache_data)
    return cache_data
if __name__ == "__main__":
    data_folder = "ESPN_data"
    process_data(data_folder)
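
# Example of how the cached output could be reused for retrieval afterwards
# (illustrative only; the query text, top_k value, and the search() sketch above
# are assumptions rather than part of the original Space):
#
#   cache = process_data("ESPN_data")
#   retriever = SentenceTransformerRetriever()
#   retriever.store_embeddings(cache['embeddings'])
#   results = retriever.search("last night's highest scoring game", cache['documents'], top_k=3)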