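"""Sport Chatbot: a Streamlit RAG app over ESPN CSV data.

CSV rows are embedded with a MiniLM sentence-transformer, retrieved by cosine
similarity, and passed as context to a local Mistral-7B GGUF model served
through llama-cpp-python. Everything runs on CPU.
"""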
import os
import warnings
import logging
import sys

warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from typing import List, Callable, Dict, Optional, Any
import glob
from tqdm import tqdm
import pickle
import torch.nn.functional as F
from llama_cpp import Llama
import streamlit as st
import functools
from datetime import datetime
import re
import time
import requests
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
# CPU-only: models below are created explicitly with device="cpu"
# (a bare torch.device('cpu') statement would have no global effect).
# Create necessary directories
for directory in ['models', 'ESPN_data', 'embeddings_cache']:
    os.makedirs(directory, exist_ok=True)

# Logging configuration
LOGGING_CONFIG = {
    'enabled': True,
    'functions': {
        'encode': True,
        'store_embeddings': True,
        'search': True,
        'load_and_process_csvs': True,
        'process_query': True
    }
}
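# Flip a flag above (e.g. 'search': False) to silence logging for that
# function; 'enabled': False turns the decorator's logging off entirely.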
def download_file_with_progress(url: str, filename: str):
    """Download a file with a progress bar using requests."""
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()  # Fail fast instead of writing an HTML error page to disk
    total_size = int(response.headers.get('content-length', 0))
    with open(filename, 'wb') as file, tqdm(
        desc=filename,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)
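# Note: tqdm writes to the server console, so this download progress is
# visible in the application logs rather than on the Streamlit page.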
def log_function(func: Callable) -> Callable:
    """Decorator to log function inputs and outputs."""
    @functools.wraps(func)  # Preserve the wrapped function's __name__ for the LOGGING_CONFIG lookup
    def wrapper(*args, **kwargs):
        if not LOGGING_CONFIG['enabled'] or not LOGGING_CONFIG['functions'].get(func.__name__, False):
            return func(*args, **kwargs)
        if args and hasattr(args[0], '__class__'):
            class_name = args[0].__class__.__name__
        else:
            class_name = func.__module__
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        log_args = args[1:] if class_name != func.__module__ else args

        def format_arg(arg):
            if isinstance(arg, torch.Tensor):
                return f"Tensor(shape={list(arg.shape)}, device={arg.device})"
            elif isinstance(arg, list):
                return f"List(len={len(arg)})"
            elif isinstance(arg, str) and len(arg) > 100:
                return f"String(len={len(arg)}): {arg[:100]}..."
            return arg

        formatted_args = [format_arg(arg) for arg in log_args]
        formatted_kwargs = {k: format_arg(v) for k, v in kwargs.items()}
        print(f"\n{'='*80}")
        print(f"[{timestamp}] FUNCTION CALL: {class_name}.{func.__name__}")
        print("INPUTS:")
        print(f"  args: {formatted_args}")
        print(f"  kwargs: {formatted_kwargs}")
        result = func(*args, **kwargs)
        formatted_result = format_arg(result)
        print("OUTPUT:")
        print(f"  {formatted_result}")
        print(f"{'='*80}\n")
        return result
    return wrapper
def check_environment():
    """Check if the environment is properly set up."""
    try:
        import numpy as np
        import torch
        import sentence_transformers
        import llama_cpp
        return True
    except ImportError as e:
        st.error(f"Missing required package: {str(e)}")
        st.stop()
        return False
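# Retriever: MiniLM sentence embeddings, L2-normalized and searched with cosine
# similarity; embeddings and documents are cached together in a pickle file.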
class SentenceTransformerRetriever:
    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", cache_dir: str = "embeddings_cache"):
        self.device = torch.device("cpu")
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.cache_file = "embeddings.pkl"
        self.doc_embeddings = None
        os.makedirs(cache_dir, exist_ok=True)
        self.model = self._load_model(model_name)
    # The underscore-prefixed parameters follow Streamlit's caching convention:
    # they are excluded from the cache key, so one model instance is shared
    # per process rather than reloaded on every rerun.
    @st.cache_resource
    def _load_model(_self, _model_name: str):
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = SentenceTransformer(_model_name, device="cpu")
                test_embedding = model.encode("test", convert_to_tensor=True)
                if not isinstance(test_embedding, torch.Tensor):
                    raise ValueError("Model initialization failed")
            return model
        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise
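    # Note: a single shared cache file is used regardless of data_folder,
    # which is why the parameter below is accepted but unused.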
    def get_cache_path(self, data_folder: str = None) -> str:
        return os.path.join(self.cache_dir, self.cache_file)
    def save_cache(self, data_folder: str, cache_data: dict):
        try:
            cache_path = self.get_cache_path()
            if os.path.exists(cache_path):
                os.remove(cache_path)
            with open(cache_path, 'wb') as f:
                pickle.dump(cache_data, f)
            logging.info(f"Cache saved at: {cache_path}")
        except Exception as e:
            logging.error(f"Error saving cache: {str(e)}")
            raise
    def load_cache(_self, _data_folder: str = None) -> Optional[Dict]:
        try:
            cache_path = _self.get_cache_path()
            if os.path.exists(cache_path):
                with open(cache_path, 'rb') as f:
                    logging.info(f"Loading cache from: {cache_path}")
                    cache_data = pickle.load(f)
                if isinstance(cache_data, dict) and 'embeddings' in cache_data and 'documents' in cache_data:
                    return cache_data
                logging.warning("Invalid cache format")
            return None
        except Exception as e:
            logging.error(f"Error loading cache: {str(e)}")
            return None
    def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:
        try:
            # Show a Streamlit progress bar instead of tqdm
            progress_bar = st.progress(0)
            all_embeddings = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                batch_embeddings = self.model.encode(
                    batch,
                    convert_to_tensor=True,
                    show_progress_bar=False  # Disable the tqdm progress bar
                )
                all_embeddings.append(batch_embeddings)
                # Update progress
                progress = min((i + batch_size) / len(texts), 1.0)
                progress_bar.progress(progress)
            # Clear progress bar
            progress_bar.empty()
            # Concatenate all batches and L2-normalize the rows
            embeddings = torch.cat(all_embeddings, dim=0)
            return F.normalize(embeddings, p=2, dim=1)
        except Exception as e:
            logging.error(f"Error encoding texts: {str(e)}")
            raise
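    # Because encode() L2-normalizes its output, the cosine similarity used in
    # search() is equivalent to a plain dot product over the stored embeddings.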
    def store_embeddings(self, embeddings: torch.Tensor):
        self.doc_embeddings = embeddings

    def search(self, query_embedding: torch.Tensor, k: int, documents: List[str]):
        try:
            if self.doc_embeddings is None:
                raise ValueError("No document embeddings stored!")
            similarities = F.cosine_similarity(query_embedding, self.doc_embeddings)
            k = min(k, len(documents))
            scores, indices = torch.topk(similarities, k=k)
            logging.info("\nSimilarity Stats:")
            logging.info(f"Max similarity: {similarities.max().item():.4f}")
            logging.info(f"Mean similarity: {similarities.mean().item():.4f}")
            logging.info(f"Selected similarities: {scores.tolist()}")
            return indices.cpu(), scores.cpu()
        except Exception as e:
            logging.error(f"Error in search: {str(e)}")
            raise
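# Pipeline: load and embed the ESPN CSVs first, then fetch and initialize a
# local Mistral-7B GGUF model via llama-cpp-python to answer questions over
# the retrieved rows.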
class RAGPipeline:
    def __init__(self, data_folder: str, k: int = 5):
        self.data_folder = data_folder
        self.k = k
        self.retriever = SentenceTransformerRetriever()
        self.documents = []
        self.device = torch.device("cpu")
        # Process documents first so embeddings are ready before the LLM loads
        self.load_and_process_csvs()
        # Model file lives next to the app; downloaded on first run if missing
        self.model_path = "mistral-7b-v0.1.Q4_K_M.gguf"
        self.llm = None
        self._initialize_model()
    def _initialize_model(_self):
        try:
            if not os.path.exists(_self.model_path):
                direct_url = "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_K_M.gguf"
                download_file_with_progress(direct_url, _self.model_path)
            if not os.path.exists(_self.model_path):
                raise FileNotFoundError(f"Model file {_self.model_path} not found after download attempts")
            llm_config = {
                "n_ctx": 2048,
                "n_threads": 4,
                "n_batch": 512,
                "n_gpu_layers": 0,  # CPU-only inference
                "verbose": True     # Verbose llama.cpp output for easier debugging
            }
            _self.llm = Llama(model_path=_self.model_path, **llm_config)
            st.success("Model loaded successfully!")
        except Exception as e:
            logging.error(f"Error initializing model: {str(e)}")
            st.error(f"Error initializing model: {str(e)}")
            raise
    def check_model_health(self):
        try:
            if self.llm is None:
                return False
            test_response = self.llm(
                "Test prompt",
                max_tokens=10,
                temperature=0.4,
                echo=False
            )
            return isinstance(test_response, dict) and 'choices' in test_response
        except Exception:
            return False
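    # Ingestion: every CSV row becomes one retrievable "document" by
    # stringifying all of its columns and joining them with spaces.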
    def load_and_process_csvs(_self):
        try:
            # Try loading from cache first
            cache_data = _self.retriever.load_cache(_self.data_folder)
            if cache_data is not None:
                _self.documents = cache_data['documents']
                _self.retriever.store_embeddings(cache_data['embeddings'])
                st.success("Loaded documents from cache")
                return

            st.info("Processing documents... This may take a while.")
            csv_files = glob.glob(os.path.join(_self.data_folder, "*.csv"))
            if not csv_files:
                raise FileNotFoundError(f"No CSV files found in {_self.data_folder}")

            all_documents = []
            total_files = len(csv_files)
            progress_bar = st.progress(0)
            for idx, csv_file in enumerate(csv_files):
                try:
                    df = pd.read_csv(csv_file, low_memory=False)
                    texts = df.apply(lambda x: " ".join(x.astype(str)), axis=1).tolist()
                    all_documents.extend(texts)
                    # Update progress
                    progress_bar.progress((idx + 1) / total_files)
                except Exception as e:
                    logging.error(f"Error processing file {csv_file}: {e}")
                    continue
            progress_bar.empty()

            if not all_documents:
                raise ValueError("No documents were successfully loaded")

            st.info(f"Processing {len(all_documents)} documents...")
            _self.documents = all_documents
            embeddings = _self.retriever.encode(all_documents)
            _self.retriever.store_embeddings(embeddings)

            # Save to cache
            cache_data = {
                'embeddings': embeddings,
                'documents': _self.documents
            }
            _self.retriever.save_cache(_self.data_folder, cache_data)
            st.success("Document processing complete!")
        except Exception as e:
            logging.error(f"Error in load_and_process_csvs: {str(e)}")
            raise
    def preprocess_query(self, query: str) -> str:
        query = query.lower().strip()
        query = re.sub(r'\s+', ' ', query)
        return query

    def postprocess_response(self, response: str) -> str:
        response = response.strip()
        response = re.sub(r'\s+', ' ', response)
        # Strip any leaked "YYYY-MM-DD HH:MM:SS(+TZ)" timestamps from the answer
        response = re.sub(r'\d{4}-\d{2}-\d{2}\s\d{2}:\d{2}:\d{2}(?:\+\d{2}:?\d{2})?', '', response)
        return response
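    # Query flow: normalize the query, embed it, retrieve the top-k rows by
    # cosine similarity, then prompt the local Mistral model with those rows
    # as context, streaming status updates into the given placeholder.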
    def process_query(self, query: str, placeholder) -> str:
        try:
            if self.llm is None:
                raise RuntimeError("LLM model not initialized")
            if self.retriever.model is None:
                raise RuntimeError("Sentence transformer model not initialized")

            query = self.preprocess_query(query)
            status = placeholder.empty()
            status.write("🔍 Finding relevant information...")

            query_embedding = self.retriever.encode([query])
            indices, scores = self.retriever.search(query_embedding, self.k, self.documents)
            logging.info("\nSearch Results:")
            for idx, score in zip(indices.tolist(), scores.tolist()):
                logging.info(f"Score: {score:.4f} | Document: {self.documents[idx][:100]}...")
            relevant_docs = [self.documents[idx] for idx in indices.tolist()]

            status.write("📝 Generating response...")
            context = "\n".join(relevant_docs)
            prompt = f"""Context information is below:
{context}
Given the context above, please answer the following question:
{query}
Guidelines:
- If you cannot answer based on the context, say so politely
- Keep the response concise and focused
- Only include sports-related information
- No dates or timestamps in the response
- Use clear, natural language
Answer:"""
            response_placeholder = placeholder.empty()
            try:
                response = self.llm(
                    prompt,
                    max_tokens=512,
                    temperature=0.4,
                    top_p=0.95,
                    echo=False,
                    stop=["Question:", "\n\n"]
                )
                if response and 'choices' in response and len(response['choices']) > 0:
                    generated_text = response['choices'][0].get('text', '').strip()
                    if generated_text:
                        final_response = self.postprocess_response(generated_text)
                        response_placeholder.markdown(final_response)
                        return final_response
                    else:
                        message = "No relevant answer found. Please try rephrasing your question."
                        response_placeholder.warning(message)
                        return message
                else:
                    message = "Unable to generate response. Please try again."
                    response_placeholder.warning(message)
                    return message
            except Exception as e:
                logging.error(f"Generation error: {str(e)}")
                message = "Had some trouble generating the response. Please try again."
                response_placeholder.warning(message)
                return message
        except Exception as e:
            logging.error(f"Process error: {str(e)}")
            message = "Something went wrong. Please try again with a different question."
            placeholder.warning(message)
            return message
# Cache the heavyweight pipeline across Streamlit reruns so "once" actually
# holds; without this, every button click would rebuild the embeddings and
# reload the multi-gigabyte GGUF model.
@st.cache_resource
def initialize_rag_pipeline():
    """Initialize the RAG pipeline once per process."""
    try:
        data_folder = "ESPN_data"
        if not os.path.exists(data_folder):
            os.makedirs(data_folder, exist_ok=True)

        # Check for cache
        cache_path = os.path.join("embeddings_cache", "embeddings.pkl")
        if os.path.exists(cache_path):
            st.info("Found cached data. Loading...")
        else:
            st.warning("Initial setup may take several minutes...")

        return RAGPipeline(data_folder)
    except Exception as e:
        logging.error(f"Pipeline initialization error: {str(e)}")
        st.error("Failed to initialize the system. Please check if all required files are present.")
        raise
def main():
    try:
        # Page config must be the first Streamlit call in the script
        st.set_page_config(
            page_title="The Sport Chatbot",
            page_icon="🏆",
            layout="wide"
        )

        # Environment check (may emit st.error, so it runs after set_page_config)
        if not check_environment():
            return
        # Improved CSS styling
        st.markdown("""
            <style>
            .block-container {
                padding-top: 2rem;
                padding-bottom: 2rem;
            }
            .stTextInput > div > div > input {
                width: 100%;
            }
            .stButton > button {
                width: 200px;
                margin: 0 auto;
                display: block;
                background-color: #FF4B4B;
                color: white;
                border-radius: 5px;
                padding: 0.5rem 1rem;
            }
            .main-title {
                text-align: center;
                padding: 1rem 0;
                font-size: 3rem;
                color: #1F1F1F;
            }
            .sub-title {
                text-align: center;
                padding: 0.5rem 0;
                font-size: 1.5rem;
                color: #4F4F4F;
            }
            .description {
                text-align: center;
                color: #666666;
                padding: 0.5rem 0;
                font-size: 1.1rem;
                line-height: 1.6;
                margin-bottom: 1rem;
            }
            .stMarkdown {
                max-width: 100%;
            }
            .st-emotion-cache-16idsys p {
                font-size: 1.1rem;
                line-height: 1.6;
            }
            .main-content {
                max-width: 1200px;
                margin: 0 auto;
                padding: 0 1rem;
            }
            </style>
        """, unsafe_allow_html=True)
        # Header section
        st.markdown("<h1 class='main-title'>🏆 The Sport Chatbot</h1>", unsafe_allow_html=True)
        st.markdown("<h3 class='sub-title'>Using ESPN API</h3>", unsafe_allow_html=True)
        st.markdown("""
            <p class='description'>
                Hey there! 👋 I can help you with information on Ice Hockey, Baseball, American Football, Soccer, and Basketball.
                With access to the ESPN API, I'm up to date with the latest details for these sports up until October 2024.
            </p>
            <p class='description'>
                Got any general questions? Feel free to ask, and I'll do my best to provide answers based on the information I've been trained on!
            </p>
        """, unsafe_allow_html=True)
        # Add spacing
        st.markdown("<br>", unsafe_allow_html=True)

        # Initialize the pipeline
        try:
            with st.spinner("Loading resources..."):
                rag = initialize_rag_pipeline()
                # Quick model health check before accepting queries
                if not rag.check_model_health():
                    st.error("Model initialization failed. Please try restarting the application.")
                    return
        except Exception as e:
            logging.error(f"Initialization error: {str(e)}")
            st.error("Unable to initialize the system. Please check if all required files are present.")
            return
        # Center the content in a wide middle column
        col1, col2, col3 = st.columns([1, 6, 1])
        with col2:
            # Query input
            query = st.text_input("What would you like to know about sports?")

            # Centered button
            if st.button("Get Answer"):
                if query:
                    response_placeholder = st.empty()
                    try:
                        response = rag.process_query(query, response_placeholder)
                        logging.info(f"Generated response: {response}")
                    except Exception as e:
                        logging.error(f"Query processing error: {str(e)}")
                        response_placeholder.warning("Unable to process your question. Please try again.")
                else:
                    st.warning("Please enter a question!")
        # Footer
        st.markdown("<br><br>", unsafe_allow_html=True)
        st.markdown("---")
        st.markdown("""
            <p style='text-align: center; color: #666666; padding: 1rem 0;'>
                Powered by ESPN Data & Mistral AI 🚀
            </p>
        """, unsafe_allow_html=True)
    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logging.error(f"Application error: {str(e)}")
        st.error("An unexpected error occurred. Please check the logs and try again.")