RobertoBarrosoLuque
Filter to only top 5 categories
099c385
raw
history blame
5.5 kB
import pandas as pd
from datasets import load_dataset
from pathlib import Path
import numpy as np
import faiss
import bm25s
from src.fireworks.inference import create_client
from src.config import EMBEDDING_MODEL
# Base directory for all persisted artifacts: the third ancestor of this file
# (parents[0] is the containing directory). The "data" folder used by the
# functions below is resolved relative to it.
# NOTE(review): presumably the repository root - confirm against the project layout.
_FILE_PATH = Path(__file__).parents[2]
def load_amazon_raw_product_data() -> pd.DataFrame:
    """
    Load the raw Amazon products dataset.
    Returns:
        DataFrame built from the "train" split of "ckandemir/amazon-products"
    """
    train_split = load_dataset("ckandemir/amazon-products")["train"]
    return train_split.to_pandas()
def load_clean_amazon_product_data() -> pd.DataFrame:
    """
    Load the prepared Amazon product data from the saved parquet file.
    Returns:
        DataFrame read from data/amazon_products.parquet
    """
    parquet_path = _FILE_PATH / "data" / "amazon_products.parquet"
    return pd.read_parquet(parquet_path)
def prepare_amazon_product_data(
    df: pd.DataFrame,
    toys_sample_size: int = 650,
    top_n_categories: int = 5,
) -> pd.DataFrame:
    """
    Data preparation for Amazon products.

    Builds a lower-cased "FullText" column, splits "Category" into up to three
    levels, drops rows without a main and secondary category, removes rows with
    duplicate "FullText", downsamples the "Toys & Games" main category and keeps
    only the most frequent main categories. The input DataFrame is not modified.
    Args:
        df: DataFrame with 'Product Name', 'Category', 'Description' columns
        toys_sample_size: Maximum number of "Toys & Games" rows to keep
        top_n_categories: Number of most frequent MainCategory values to keep
    Returns:
        DataFrame with columns 'Product Name', 'Description', 'MainCategory',
        'SecondaryCategory', 'TertiaryCategory' and 'FullText'
    """
    # Work on a copy so the caller's DataFrame does not gain extra columns
    df = df.copy()

    # Full text is combination of Product Name + Category + Description
    full_text = (
        df["Product Name"] + " | " + df["Category"] + " | " + df["Description"]
    )
    df["FullText"] = full_text.str.lower().str.strip().str.replace("\n", " ")

    # Split "A | B | C ..." into at most three levels. reindex guarantees three
    # columns even when no category in the data has a second or third level.
    category_levels = (
        df["Category"]
        .str.split(r" \| ", n=2, expand=True, regex=True)
        .reindex(columns=range(3))
    )
    df[["MainCategory", "SecondaryCategory", "TertiaryCategory"]] = category_levels
    df = df.dropna(subset=["MainCategory", "SecondaryCategory"])

    # Drop dupes
    df = df.drop_duplicates(subset=["FullText"])

    # Downsample where MainCategory == Toys and Games since in raw data its over 70% of data.
    # Only sample when there are more rows than requested; sampling more rows
    # than exist (without replacement) raises ValueError.
    is_toys = df["MainCategory"] == "Toys & Games"
    df_toys = df[is_toys]
    if len(df_toys) > toys_sample_size:
        df_toys = df_toys.sample(n=toys_sample_size, random_state=42)
    df = pd.concat([df[~is_toys], df_toys])

    # Filter to only the top MainCategories by row count
    top_categories = df["MainCategory"].value_counts().index[:top_n_categories]
    df = df[df["MainCategory"].isin(top_categories)]

    print(
        f"Prepared dataset with {len(df)} products with \n Count of MainCategories: {df['MainCategory'].value_counts()}"
    )
    return df.loc[
        :,
        [
            "Product Name",
            "Description",
            "MainCategory",
            "SecondaryCategory",
            "TertiaryCategory",
            "FullText",
        ],
    ]
def save_as_parquet(df: pd.DataFrame):
    """
    Save DataFrame to parquet file.

    The file is written to data/amazon_products.parquet without the index.
    The "data" directory is created if it does not exist yet.
    Args:
        df: DataFrame to persist
    """
    # Writing fails if the target directory is missing, so create it up front
    (_FILE_PATH / "data").mkdir(parents=True, exist_ok=True)
    df.to_parquet(_FILE_PATH / "data" / "amazon_products.parquet", index=False)
    print(f"Saved to {_FILE_PATH / 'data' / 'amazon_products.parquet'}")
def create_faiss_index(df: pd.DataFrame, batch_size: int = 100):
    """
    Create FAISS index from product data using Fireworks AI embeddings.

    Embeds the 'FullText' column in batches, L2-normalizes the vectors, adds
    them to a flat L2 index and writes both the index (data/faiss_index.bin)
    and the normalized embeddings (data/embeddings.npy) to disk.
    Args:
        df: DataFrame with 'FullText' column to embed
        batch_size: Number of texts to embed in each API call
    Returns:
        Tuple of (faiss_index, embeddings_array); the embeddings are the
        normalized float32 vectors that were added to the index
    Raises:
        ValueError: If df has no rows or batch_size is not positive
    """
    texts = df["FullText"].tolist()
    # Fail early with a clear message instead of an IndexError on shape[1]
    # after the (empty) embedding loop
    if not texts:
        raise ValueError("Cannot create a FAISS index from an empty DataFrame")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    client = create_client()
    print(f"Generating embeddings for {len(df)} products...")

    total_batches = (len(texts) + batch_size - 1) // batch_size
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        print(f"Processing batch {i // batch_size + 1}/{total_batches}")
        response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
        all_embeddings.extend(item.embedding for item in response.data)

    embeddings_array = np.array(all_embeddings, dtype=np.float32)
    dimension = embeddings_array.shape[1]

    # Normalize embeddings in place: for unit-length vectors, ordering by L2
    # distance matches ordering by cosine similarity
    faiss.normalize_L2(embeddings_array)
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings_array)
    print(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")

    # Make sure the output directory exists before writing the artifacts
    (_FILE_PATH / "data").mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(_FILE_PATH / "data" / "faiss_index.bin"))
    np.save(_FILE_PATH / "data" / "embeddings.npy", embeddings_array)
    print(f"Saved FAISS index to {_FILE_PATH / 'data' / 'faiss_index.bin'}")
    print(f"Saved embeddings to {_FILE_PATH / 'data' / 'embeddings.npy'}")
    return index, embeddings_array
def load_faiss_index():
    """
    Read the stored FAISS index and the stored embeddings array from the data directory.
    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    data_dir = _FILE_PATH / "data"
    faiss_index = faiss.read_index(str(data_dir / "faiss_index.bin"))
    stored_embeddings = np.load(data_dir / "embeddings.npy")
    print(f"Loaded FAISS index with {faiss_index.ntotal} vectors")
    return faiss_index, stored_embeddings
def create_bm25_index(df: pd.DataFrame):
    """
    Build a BM25 lexical-search index over the product texts and save it to disk.
    Args:
        df: DataFrame whose 'FullText' column is tokenized and indexed
    Returns:
        BM25 index object
    """
    print(f"Creating BM25 index for {len(df)} products...")

    # Tokenize every document with English stopwords removed
    tokenized_corpus = bm25s.tokenize(df["FullText"].tolist(), stopwords="en")

    bm25_retriever = bm25s.BM25()
    bm25_retriever.index(tokenized_corpus)

    index_dir = _FILE_PATH / "data" / "bm25_index"
    bm25_retriever.save(index_dir)
    print(f"Saved BM25 index to {_FILE_PATH / 'data' / 'bm25_index'}")
    return bm25_retriever
def load_bm25_index():
    """
    Read the stored BM25 index from the data directory (corpus is not loaded).
    Returns:
        BM25 index object
    """
    index_dir = _FILE_PATH / "data" / "bm25_index"
    bm25_retriever = bm25s.BM25.load(index_dir, load_corpus=False)
    print("Loaded BM25 index")
    return bm25_retriever
if __name__ == "__main__":
    # Rebuild every artifact from the raw dataset: cleaned parquet file,
    # BM25 index and FAISS index (in that order).
    _products = prepare_amazon_product_data(load_amazon_raw_product_data())
    save_as_parquet(_products)
    create_bm25_index(_products)
    create_faiss_index(_products)