"""
Vector database manager for NBA data using ChromaDB and sentence-transformers.
"""
import os
import pandas as pd
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import json


class NBAVectorDB:
    """
    Manages vector embeddings and semantic search for NBA data.
    Uses sentence-transformers for embeddings and ChromaDB for storage.
    """

    def __init__(
        self,
        csv_path: str,
        collection_name: str = "nba_data",
        persist_directory: str = "./chroma_db",
    ):
        """
        Initialize the vector database.

        Args:
            csv_path: Path to the NBA CSV file
            collection_name: Name of the ChromaDB collection
            persist_directory: Directory to persist the vector database
        """
        self.csv_path = csv_path
        self.collection_name = collection_name
        self.persist_directory = persist_directory

        # Initialize the embedding model (open-source, runs locally).
        # all-MiniLM-L6-v2 is fast, good quality, and produces
        # 384-dimensional embeddings.
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("Embedding model loaded!")

        # Initialize ChromaDB client
        os.makedirs(persist_directory, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=persist_directory,
            settings=Settings(anonymized_telemetry=False)
        )

        # Get or create the collection. Cosine space is set explicitly so
        # the distance-to-similarity conversion in search() is meaningful
        # (ChromaDB defaults to squared L2 distance otherwise).
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={
                "description": "NBA 2024-25 season data",
                "hnsw:space": "cosine",
            }
        )

        # Index the CSV on first use, when the collection is still empty
        if self.collection.count() == 0:
            print("Vector database is empty. Indexing CSV data...")
            self._index_csv()
        else:
            print(f"Vector database loaded with {self.collection.count()} records")

    def _create_text_representation(self, row: pd.Series) -> str:
        """
        Convert a DataFrame row to a text representation for embedding.

        Args:
            row: A pandas Series representing a row

        Returns:
            str: Text representation of the row
        """
        # Build a natural-language description of the row, skipping
        # columns that are absent or hold missing values
        parts = []
        if 'Player' in row:
            parts.append(f"Player: {row['Player']}")
        if 'Tm' in row:
            parts.append(f"Team: {row['Tm']}")
        if 'Opp' in row:
            parts.append(f"Opponent: {row['Opp']}")
        if 'Res' in row and pd.notna(row['Res']):
            parts.append(f"Result: {'Win' if row['Res'] == 'W' else 'Loss'}")
        if 'PTS' in row and pd.notna(row['PTS']):
            parts.append(f"Points: {row['PTS']}")
        if 'AST' in row and pd.notna(row['AST']):
            parts.append(f"Assists: {row['AST']}")
        if 'TRB' in row and pd.notna(row['TRB']):
            parts.append(f"Rebounds: {row['TRB']}")
        if 'FG%' in row and pd.notna(row['FG%']):
            parts.append(f"Field Goal Percentage: {row['FG%']:.3f}")
        if '3P%' in row and pd.notna(row['3P%']):
            parts.append(f"Three Point Percentage: {row['3P%']:.3f}")
        if 'Data' in row:  # the source CSV names its date column 'Data'
            parts.append(f"Date: {row['Data']}")
        return ". ".join(parts)

    def _index_csv(self):
        """
        Read the CSV file, create embeddings, and store them in ChromaDB.
        """
        print(f"Reading CSV from {self.csv_path}...")
        df = pd.read_csv(self.csv_path)
        print(f"Creating embeddings for {len(df)} records...")

        # Process in batches so the encoder and ChromaDB each receive
        # manageable chunks
        batch_size = 100
        total_batches = (len(df) + batch_size - 1) // batch_size

        for batch_idx in range(0, len(df), batch_size):
            batch_df = df.iloc[batch_idx:batch_idx + batch_size]
            batch_num = (batch_idx // batch_size) + 1

            batch_texts = []
            batch_metadatas = []
            batch_ids = []

            for idx, row in batch_df.iterrows():
                # Create the text representation that gets embedded
                text = self._create_text_representation(row)
                batch_texts.append(text)

                # Store a flat metadata record alongside each embedding
                # (ChromaDB metadata values must be scalars)
                metadata = {
                    'row_index': int(idx),
                    'player': str(row.get('Player', '')),
                    'team': str(row.get('Tm', '')),
                    'opponent': str(row.get('Opp', '')),
                    'result': str(row.get('Res', '')),
                    'points': float(row.get('PTS', 0)) if pd.notna(row.get('PTS')) else 0.0,
                    'date': str(row.get('Data', '')),
                }
                batch_metadatas.append(metadata)
                batch_ids.append(f"row_{idx}")

            # Generate embeddings for this batch
            print(f"Processing batch {batch_num}/{total_batches} ({len(batch_texts)} records)...")
            embeddings = self.embedding_model.encode(
                batch_texts,
                show_progress_bar=False,
                convert_to_numpy=True
            ).tolist()

            # Add the batch to ChromaDB
            self.collection.add(
                embeddings=embeddings,
                documents=batch_texts,
                metadatas=batch_metadatas,
                ids=batch_ids
            )

        print(f"Successfully indexed {len(df)} records in vector database!")

    def search(self, query: str, n_results: int = 10) -> List[Dict]:
        """
        Perform semantic search on the NBA data.

        Args:
            query: Natural language query
            n_results: Number of results to return

        Returns:
            List of dictionaries containing search results with metadata
        """
        # Generate an embedding for the query
        query_embedding = self.embedding_model.encode(
            query,
            convert_to_numpy=True
        ).tolist()

        # Search in ChromaDB
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            include=['documents', 'metadatas', 'distances']
        )

        # Format results
        formatted_results = []
        if results['ids'] and len(results['ids'][0]) > 0:
            for i in range(len(results['ids'][0])):
                formatted_results.append({
                    'id': results['ids'][0][i],
                    'document': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i],
                    # With cosine space, distance = 1 - cosine similarity,
                    # so this recovers the cosine similarity
                    'similarity': 1 - results['distances'][0][i],
                })
        return formatted_results
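
    # A single search() hit has roughly this shape (values are illustrative):
    #   {'id': 'row_42',
    #    'document': 'Player: ... Points: 31. ...',
    #    'metadata': {'row_index': 42, 'player': '...', 'points': 31.0, ...},
    #    'distance': 0.21,
    #    'similarity': 0.79}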

    def get_original_row(self, row_index: int) -> Optional[pd.Series]:
        """
        Retrieve the original CSV row by index.

        Args:
            row_index: Index of the row in the original CSV

        Returns:
            pandas Series or None if not found
        """
        try:
            # Re-reads the CSV on every call; cache the DataFrame if this
            # becomes a hot path
            df = pd.read_csv(self.csv_path)
            if 0 <= row_index < len(df):
                return df.iloc[row_index]
        except Exception as e:
            print(f"Error retrieving row {row_index}: {e}")
        return None


# Global instance (initialized lazily on first use)
_vector_db_instance: Optional[NBAVectorDB] = None


def get_vector_db(csv_path: str) -> NBAVectorDB:
    """
    Get or create the global vector database instance.

    Args:
        csv_path: Path to the CSV file

    Returns:
        NBAVectorDB instance
    """
    global _vector_db_instance
    if _vector_db_instance is None or _vector_db_instance.csv_path != csv_path:
        _vector_db_instance = NBAVectorDB(csv_path)
    return _vector_db_instance
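

# Minimal usage sketch (the CSV path and query are hypothetical; point the
# path at a real export with the columns referenced above):
if __name__ == "__main__":
    db = get_vector_db("./data/nba_2024_25.csv")  # hypothetical path
    for hit in db.search("high scoring games against the Celtics", n_results=5):
        print(f"{hit['similarity']:.3f}  {hit['document']}")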