Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| from datasets import load_dataset | |
| from pathlib import Path | |
| import numpy as np | |
| import faiss | |
| import bm25s | |
| from src.fireworks.inference import create_client | |
| from src.config import EMBEDDING_MODEL | |
| _FILE_PATH = Path(__file__).parents[2] | |
def load_amazon_raw_product_data() -> pd.DataFrame:
    """Download the ckandemir/amazon-products dataset and return its train split."""
    dataset = load_dataset("ckandemir/amazon-products")
    return dataset["train"].to_pandas()
def load_clean_amazon_product_data() -> pd.DataFrame:
    """Read the pre-processed Amazon products table from the local parquet file."""
    parquet_path = _FILE_PATH / "data" / "amazon_products.parquet"
    return pd.read_parquet(parquet_path)
def prepare_amazon_product_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean and subsample the raw Amazon products data.

    Args:
        df: DataFrame with 'Product Name', 'Category', 'Description' columns

    Returns:
        DataFrame with a lowercased 'FullText' column, the category split into
        Main/Secondary/Tertiary columns, duplicates removed, 'Toys & Games'
        downsampled, and restricted to the 5 most frequent main categories.
    """
    # Work on a copy so the caller's frame is never mutated
    # (also avoids pandas SettingWithCopyWarning on chained assignment).
    df = df.copy()
    # Full text is combination of Product Name + Category + Description.
    df["FullText"] = (
        df["Product Name"] + " | " + df["Category"] + " | " + df["Description"]
    )
    df["FullText"] = df["FullText"].str.lower().str.strip().str.replace("\n", " ")
    # Category is formatted "Main | Secondary | Tertiary"; split into columns.
    df[["MainCategory", "SecondaryCategory", "TertiaryCategory"]] = df[
        "Category"
    ].str.split(r" \| ", n=2, expand=True, regex=True)
    # Rows without at least a main and a secondary category are unusable.
    df = df.dropna(subset=["MainCategory", "SecondaryCategory"])
    # Drop duplicate products (identical FullText).
    df = df.drop_duplicates(subset=["FullText"])
    # Downsample 'Toys & Games' to at most 650 rows since in the raw data
    # it accounts for over 70% of all products.
    df_non_toys = df[df["MainCategory"] != "Toys & Games"]
    df_toys = df[df["MainCategory"] == "Toys & Games"]
    # min() guards against ValueError when fewer than 650 toy rows survive
    # the dropna/dedup steps above.
    df_toys = df_toys.sample(n=min(650, len(df_toys)), random_state=42)
    df = pd.concat([df_non_toys, df_toys])
    # Filter to only top 5 MainCategories
    df = df[df["MainCategory"].isin(df["MainCategory"].value_counts().index[:5])]
    print(
        f"Prepared dataset with {len(df)} products with \n Count of MainCategories: {df['MainCategory'].value_counts()}"
    )
    return df.loc[
        :,
        [
            "Product Name",
            "Description",
            "MainCategory",
            "SecondaryCategory",
            "TertiaryCategory",
            "FullText",
        ],
    ]
def save_as_parquet(df: pd.DataFrame):
    """
    Save DataFrame to parquet file under the project's data directory.
    """
    out_path = _FILE_PATH / "data" / "amazon_products.parquet"
    df.to_parquet(out_path, index=False)
    print(f"Saved to {out_path}")
def create_faiss_index(df: pd.DataFrame, batch_size: int = 100):
    """
    Create FAISS index from product data using Fireworks AI embeddings.

    Args:
        df: DataFrame with 'FullText' column to embed
        batch_size: Number of texts to embed in each API call

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    client = create_client()
    print(f"Generating embeddings for {len(df)} products...")
    texts = df["FullText"].tolist()
    total_batches = (len(texts) + batch_size - 1) // batch_size
    vectors = []
    for batch_num, start in enumerate(range(0, len(texts), batch_size), start=1):
        print(f"Processing batch {batch_num}/{total_batches}")
        response = client.embeddings.create(
            model=EMBEDDING_MODEL, input=texts[start : start + batch_size]
        )
        vectors.extend(item.embedding for item in response.data)
    embeddings_array = np.array(vectors, dtype=np.float32)
    dimension = embeddings_array.shape[1]
    # Vectors are L2-normalized below, so L2 distance ranks like cosine similarity.
    index = faiss.IndexFlatL2(dimension)
    faiss.normalize_L2(embeddings_array)
    index.add(embeddings_array)
    print(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
    # Persist both the index and the (normalized) raw embeddings for reuse.
    faiss.write_index(index, str(_FILE_PATH / "data" / "faiss_index.bin"))
    np.save(_FILE_PATH / "data" / "embeddings.npy", embeddings_array)
    print(f"Saved FAISS index to {_FILE_PATH / 'data' / 'faiss_index.bin'}")
    print(f"Saved embeddings to {_FILE_PATH / 'data' / 'embeddings.npy'}")
    return index, embeddings_array
def load_faiss_index():
    """
    Load pre-computed FAISS index and embeddings from disk.

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    data_dir = _FILE_PATH / "data"
    index = faiss.read_index(str(data_dir / "faiss_index.bin"))
    embeddings = np.load(data_dir / "embeddings.npy")
    print(f"Loaded FAISS index with {index.ntotal} vectors")
    return index, embeddings
def create_bm25_index(df: pd.DataFrame):
    """
    Create BM25 index from product data for lexical search.

    Args:
        df: DataFrame with 'FullText' column to index

    Returns:
        BM25 index object
    """
    print(f"Creating BM25 index for {len(df)} products...")
    # Tokenize with English stopwords removed, then build the sparse index.
    tokenized_corpus = bm25s.tokenize(df["FullText"].tolist(), stopwords="en")
    retriever = bm25s.BM25()
    retriever.index(tokenized_corpus)
    retriever.save(_FILE_PATH / "data" / "bm25_index")
    print(f"Saved BM25 index to {_FILE_PATH / 'data' / 'bm25_index'}")
    return retriever
def load_bm25_index():
    """
    Load pre-computed BM25 index from disk.

    Returns:
        BM25 index object
    """
    # Corpus text is kept in the parquet file, so skip loading it here.
    index_path = _FILE_PATH / "data" / "bm25_index"
    retriever = bm25s.BM25.load(index_path, load_corpus=False)
    print("Loaded BM25 index")
    return retriever
if __name__ == "__main__":
    # Offline pipeline: download + clean the data, persist it, build both indexes.
    products = load_amazon_raw_product_data()
    products = prepare_amazon_product_data(products)
    save_as_parquet(products)
    create_bm25_index(products)
    create_faiss_index(products)