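"""Build the Amazon products dataset and its BM25 and FAISS search indexes."""
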
import pandas as pd
from datasets import load_dataset
from pathlib import Path
import numpy as np
import faiss
import bm25s
from src.fireworks.inference import create_client
from src.config import EMBEDDING_MODEL

_FILE_PATH = Path(__file__).parents[2]


def load_amazon_raw_product_data() -> pd.DataFrame:
ds = load_dataset("ckandemir/amazon-products")
df = ds["train"].to_pandas()
    return df


def load_clean_amazon_product_data() -> pd.DataFrame:
    return pd.read_parquet(_FILE_PATH / "data" / "amazon_products.parquet")


def prepare_amazon_product_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean, deduplicate, and rebalance the raw Amazon product data.

    Args:
        df: DataFrame with 'Product Name', 'Category', and 'Description' columns

    Returns:
        DataFrame with 'Product Name', 'Description', 'MainCategory',
        'SecondaryCategory', 'TertiaryCategory', and 'FullText' columns
    """
    # FullText combines Product Name, Category, and Description
    df.loc[:, "FullText"] = (
        df["Product Name"] + " | " + df["Category"] + " | " + df["Description"]
    )
    df.loc[:, "FullText"] = (
        df["FullText"].str.lower().str.strip().str.replace("\n", " ")
    )
df[["MainCategory", "SecondaryCategory", "TertiaryCategory"]] = df[
"Category"
].str.split(r" \| ", n=2, expand=True, regex=True)
df = df.dropna(subset=["MainCategory", "SecondaryCategory"])
# Drop dupes
df = df.drop_duplicates(subset=["FullText"])
    # Downsample 'Toys & Games' to 650 rows since it's over 70% of the raw data
df_non_toys = df[df["MainCategory"] != "Toys & Games"]
df_toys = df[df["MainCategory"] == "Toys & Games"]
df_toys = df_toys.sample(n=650, random_state=42)
df = pd.concat([df_non_toys, df_toys])
# Filter to only top 5 MainCategories
df = df[df["MainCategory"].isin(df["MainCategory"].value_counts().index[:5])]
    print(
        f"Prepared dataset with {len(df)} products\n"
        f"MainCategory counts:\n{df['MainCategory'].value_counts()}"
    )
return df.loc[
:,
[
"Product Name",
"Description",
"MainCategory",
"SecondaryCategory",
"TertiaryCategory",
"FullText",
],
    ]


def save_as_parquet(df: pd.DataFrame):
"""
Save DataFrame to parquet file.
"""
df.to_parquet(_FILE_PATH / "data" / "amazon_products.parquet", index=False)
print(f"Saved to {_FILE_PATH / 'data' / 'amazon_products.parquet'}")
def create_faiss_index(df: pd.DataFrame, batch_size: int = 100):
    """
    Create a FAISS index from product data using Fireworks AI embeddings.

    Args:
        df: DataFrame with 'FullText' column to embed
        batch_size: Number of texts to embed in each API call

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
client = create_client()
print(f"Generating embeddings for {len(df)} products...")
all_embeddings = []
texts = df["FullText"].tolist()
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
print(
f"Processing batch {i // batch_size + 1}/{(len(texts) + batch_size - 1) // batch_size}"
)
response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
batch_embeddings = [item.embedding for item in response.data]
all_embeddings.extend(batch_embeddings)
embeddings_array = np.array(all_embeddings, dtype=np.float32)
dimension = embeddings_array.shape[1]
    # After L2 normalization, L2 distance is a monotonic function of cosine
    # similarity (||a - b||^2 = 2 - 2 * a.b), so a flat L2 index gives the
    # same ranking as cosine search
    index = faiss.IndexFlatL2(dimension)
    faiss.normalize_L2(embeddings_array)
index.add(embeddings_array)
print(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
faiss.write_index(index, str(_FILE_PATH / "data" / "faiss_index.bin"))
np.save(_FILE_PATH / "data" / "embeddings.npy", embeddings_array)
print(f"Saved FAISS index to {_FILE_PATH / 'data' / 'faiss_index.bin'}")
print(f"Saved embeddings to {_FILE_PATH / 'data' / 'embeddings.npy'}")
    return index, embeddings_array


def load_faiss_index():
    """
    Load the pre-computed FAISS index and embeddings from disk.

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
index = faiss.read_index(str(_FILE_PATH / "data" / "faiss_index.bin"))
embeddings = np.load(_FILE_PATH / "data" / "embeddings.npy")
print(f"Loaded FAISS index with {index.ntotal} vectors")
    return index, embeddings


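def example_faiss_query(query: str, k: int = 5) -> pd.DataFrame:
    """
    Minimal usage sketch (illustrative; not part of the build pipeline).
    Embeds a query with the same Fireworks model used at index time and
    returns the top-k matching products. Assumes the parquet file and FAISS
    index created above already exist on disk.
    """
    df = load_clean_amazon_product_data()
    index, _ = load_faiss_index()
    client = create_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL, input=[query.lower().strip()]
    )
    query_vec = np.array([response.data[0].embedding], dtype=np.float32)
    # Normalize the query exactly as the corpus vectors were normalized
    faiss.normalize_L2(query_vec)
    _distances, indices = index.search(query_vec, k)
    return df.iloc[indices[0]]

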
def create_bm25_index(df: pd.DataFrame):
    """
    Create a BM25 index from product data for lexical search.

    Args:
        df: DataFrame with 'FullText' column to index

    Returns:
        BM25 index object
    """
print(f"Creating BM25 index for {len(df)} products...")
corpus = df["FullText"].tolist()
corpus_tokens = bm25s.tokenize(corpus, stopwords="en")
retriever = bm25s.BM25()
retriever.index(corpus_tokens)
retriever.save(_FILE_PATH / "data" / "bm25_index")
print(f"Saved BM25 index to {_FILE_PATH / 'data' / 'bm25_index'}")
    return retriever


def load_bm25_index():
    """
    Load the pre-computed BM25 index from disk.

    Returns:
        BM25 index object
    """
retriever = bm25s.BM25.load(_FILE_PATH / "data" / "bm25_index", load_corpus=False)
print("Loaded BM25 index")
    return retriever


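def example_bm25_query(query: str, k: int = 5) -> pd.DataFrame:
    """
    Minimal usage sketch (illustrative; not part of the build pipeline).
    Tokenizes a query the same way the corpus was tokenized and returns the
    top-k products by BM25 score. Assumes the parquet file and BM25 index
    created above already exist on disk and share the same row order.
    """
    df = load_clean_amazon_product_data()
    retriever = load_bm25_index()
    query_tokens = bm25s.tokenize(query.lower().strip(), stopwords="en")
    # retrieve() returns (doc_indices, scores), each of shape (n_queries, k)
    doc_indices, _scores = retriever.retrieve(query_tokens, k=k)
    return df.iloc[doc_indices[0]]

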
if __name__ == "__main__":
_df = load_amazon_raw_product_data()
_df = prepare_amazon_product_data(_df)
save_as_parquet(_df)
create_bm25_index(_df)
create_faiss_index(_df)