# NBA_Analysis / vector_db.py
# shekkari21's picture
# Add NBA analysis project files
# ddabbe4
"""
Vector database manager for NBA data using ChromaDB and sentence-transformers.
"""
import os
import pandas as pd
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import json
class NBAVectorDB:
    """
    Manages vector embeddings and semantic search for NBA data.

    Each CSV row is rendered as a short text description, embedded with a
    sentence-transformers model and stored in a persistent ChromaDB
    collection together with a small metadata dictionary.
    """

    def __init__(self, csv_path: str, collection_name: str = "nba_data", persist_directory: str = "./chroma_db"):
        """
        Initialize the vector database.

        Loads the embedding model, opens (or creates) the persistent ChromaDB
        collection and indexes the CSV when the collection holds no records.

        Args:
            csv_path: Path to the NBA CSV file
            collection_name: Name of the ChromaDB collection
            persist_directory: Directory to persist the vector database
        """
        self.csv_path = csv_path
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        # Parsed CSV, loaded on first use so repeated row lookups do not
        # re-read the file from disk.
        self._df: Optional[pd.DataFrame] = None

        # Initialize embedding model (open-source, runs locally)
        # Using all-MiniLM-L6-v2: fast, good quality, 384 dimensions
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded!")

        # Initialize ChromaDB client
        os.makedirs(persist_directory, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create collection
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"description": "NBA 2024-25 season data"}
        )

        # Index only when the collection is empty; an existing collection is
        # reused as-is.
        record_count = self.collection.count()
        if record_count == 0:
            print("Vector database is empty. Indexing CSV data...")
            self._index_csv()
        else:
            print(f"Vector database loaded with {record_count} records")

    def _load_dataframe(self) -> pd.DataFrame:
        """
        Return the CSV as a DataFrame, reading it from disk only once.

        Returns:
            pandas DataFrame parsed from ``self.csv_path``
        """
        # getattr keeps this safe for instances built without __init__.
        cached = getattr(self, "_df", None)
        if cached is None:
            cached = pd.read_csv(self.csv_path)
            self._df = cached
        return cached

    @staticmethod
    def _format_percentage(value) -> str:
        """
        Format a shooting percentage with three decimals.

        Values that cannot be converted to a float are returned as plain text
        instead of raising, so one malformed cell cannot abort indexing.
        """
        try:
            return f"{float(value):.3f}"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _distance_to_similarity(distance: float, space: str = "l2") -> float:
        """
        Convert a ChromaDB distance into a cosine-style similarity.

        Args:
            distance: Distance value returned by a collection query
            space: The collection's ``hnsw:space`` setting ("l2", "cosine" or "ip")

        Returns:
            float: Similarity where 1.0 means identical direction
        """
        if space == "l2":
            # Chroma's "l2" space reports the *squared* Euclidean distance.
            # For unit-length vectors that equals 2 - 2*cos, so cos = 1 - d/2.
            return 1 - distance / 2
        # "cosine" and "ip" distances are already defined as 1 - similarity.
        return 1 - distance

    def _create_text_representation(self, row: pd.Series) -> str:
        """
        Convert a DataFrame row to a text representation for embedding.

        Columns that are absent or hold a missing value (NaN/None) are left
        out rather than rendered as "nan" or misreported as a loss.

        Args:
            row: A pandas Series representing a row

        Returns:
            str: Text representation of the row
        """
        def present(column: str) -> bool:
            return column in row and pd.notna(row[column])

        parts = []
        for column, label in (('Player', 'Player'), ('Tm', 'Team'), ('Opp', 'Opponent')):
            if present(column):
                parts.append(f"{label}: {row[column]}")
        if present('Res'):
            parts.append(f"Result: {'Win' if row['Res'] == 'W' else 'Loss'}")
        for column, label in (('PTS', 'Points'), ('AST', 'Assists'), ('TRB', 'Rebounds')):
            if present(column):
                parts.append(f"{label}: {row[column]}")
        for column, label in (('FG%', 'Field Goal Percentage'), ('3P%', 'Three Point Percentage')):
            if present(column):
                parts.append(f"{label}: {self._format_percentage(row[column])}")
        # The date column is read under the header 'Data'.
        if present('Data'):
            parts.append(f"Date: {row['Data']}")
        return ". ".join(parts)

    def _index_csv(self):
        """
        Read CSV file, create embeddings, and store in ChromaDB.

        Rows are embedded and added in batches of 100; each record is stored
        under the id ``row_<index>`` with its text and a metadata dictionary.
        """
        print(f"Reading CSV from {self.csv_path}...")
        df = self._load_dataframe()
        print(f"Creating embeddings for {len(df)} records...")

        def as_text(value) -> str:
            # Missing cells become '' instead of the literal string 'nan'.
            return '' if pd.isna(value) else str(value)

        # Process in batches for efficiency
        batch_size = 100
        total_batches = (len(df) + batch_size - 1) // batch_size
        for batch_num, start in enumerate(range(0, len(df), batch_size), start=1):
            batch_df = df.iloc[start:start + batch_size]
            batch_texts = []
            batch_metadatas = []
            batch_ids = []
            for idx, row in batch_df.iterrows():
                batch_texts.append(self._create_text_representation(row))
                # Metadata values are kept to plain str/int/float scalars.
                batch_metadatas.append({
                    'row_index': int(idx),
                    'player': as_text(row.get('Player', '')),
                    'team': as_text(row.get('Tm', '')),
                    'opponent': as_text(row.get('Opp', '')),
                    'result': as_text(row.get('Res', '')),
                    'points': float(row.get('PTS', 0)) if pd.notna(row.get('PTS')) else 0.0,
                    'date': as_text(row.get('Data', '')),
                })
                batch_ids.append(f"row_{idx}")

            # Generate embeddings for this batch
            print(f"Processing batch {batch_num}/{total_batches} ({len(batch_texts)} records)...")
            embeddings = self.embedding_model.encode(
                batch_texts,
                show_progress_bar=False,
                convert_to_numpy=True,
                # Unit-length vectors are what makes the distance-to-similarity
                # conversion in search() valid.
                normalize_embeddings=True
            ).tolist()

            # Add to ChromaDB
            self.collection.add(
                embeddings=embeddings,
                documents=batch_texts,
                metadatas=batch_metadatas,
                ids=batch_ids
            )
        print(f"Successfully indexed {len(df)} records in vector database!")

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """
        Perform semantic search on the NBA data.

        Args:
            query: Natural language query
            n_results: Number of results to return

        Returns:
            List of dictionaries with the keys 'id', 'document', 'metadata',
            'distance' and 'similarity'; empty when nothing matched
        """
        # Generate embedding for the query (normalized like the indexed rows)
        query_embedding = self.embedding_model.encode(
            query,
            convert_to_numpy=True,
            normalize_embeddings=True
        ).tolist()

        # Search in ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances']
        )
        if not results['ids'] or not results['ids'][0]:
            return []

        # The distance metric is a per-collection setting; Chroma falls back
        # to "l2" when the collection metadata does not name one.
        collection_metadata = getattr(self.collection, 'metadata', None) or {}
        space = collection_metadata.get('hnsw:space', 'l2')

        # Results come back as one inner list per query; only one query is sent.
        return [
            {
                'id': record_id,
                'document': document,
                'metadata': metadata,
                'distance': distance,
                'similarity': self._distance_to_similarity(distance, space),
            }
            for record_id, document, metadata, distance in zip(
                results['ids'][0],
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

    def get_original_row(self, row_index: int) -> Optional[pd.Series]:
        """
        Retrieve the original CSV row by index.

        Args:
            row_index: Zero-based position of the row in the original CSV

        Returns:
            pandas Series or None if the index is out of range or the CSV
            could not be read
        """
        try:
            df = self._load_dataframe()
            if 0 <= row_index < len(df):
                # Copy so callers cannot mutate the cached frame.
                return df.iloc[row_index].copy()
        except Exception as e:
            # Best-effort lookup: report the problem and fall through to None.
            print(f"Error retrieving row {row_index}: {e}")
        return None
# Module-wide cache for the most recently built database; stays None until
# get_vector_db() is first called.
_vector_db_instance: Optional[NBAVectorDB] = None


def get_vector_db(csv_path: str) -> NBAVectorDB:
    """
    Return the shared NBAVectorDB, building it on first use.

    The cached instance is reused while the requested CSV path equals the one
    it was built with; a different path replaces it with a fresh instance
    created with the default collection name and persist directory.

    Args:
        csv_path: Path to the CSV file

    Returns:
        NBAVectorDB instance
    """
    global _vector_db_instance
    cached = _vector_db_instance
    # Fast path: an instance for this exact path already exists.
    if cached is not None and cached.csv_path == csv_path:
        return cached
    _vector_db_instance = NBAVectorDB(csv_path)
    return _vector_db_instance