import pandas as pd
from datasets import load_dataset
from pathlib import Path
import numpy as np
import faiss
import bm25s
from src.fireworks.inference import create_client
from src.config import EMBEDDING_MODEL

_FILE_PATH = Path(__file__).parents[2]  # project root, two directories above this module


def load_amazon_raw_product_data() -> pd.DataFrame:
    """Load the raw Amazon products dataset from the Hugging Face Hub."""
    ds = load_dataset("ckandemir/amazon-products")
    df = ds["train"].to_pandas()
    return df


def load_clean_amazon_product_data() -> pd.DataFrame:
    """Load the prepared Amazon products dataset from the local parquet file."""
    return pd.read_parquet(_FILE_PATH / "data" / "amazon_products.parquet")


def prepare_amazon_product_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Data preparation for Amazon products.

    Args:
        df: DataFrame with 'Product Name', 'Category', 'Description' columns

    Returns:
        DataFrame
    """
    # Full text is a combination of Product Name, Category, and Description
    df.loc[:, "FullText"] = (
        df["Product Name"] + " | " + df["Category"] + " | " + df["Description"]
    )
    df.loc[:, "FullText"] = df.FullText.str.lower().str.strip().str.replace("\n", " ")

    df[["MainCategory", "SecondaryCategory", "TertiaryCategory"]] = df[
        "Category"
    ].str.split(r" \| ", n=2, expand=True, regex=True)
    df = df.dropna(subset=["MainCategory", "SecondaryCategory"])

    # Drop duplicate products (same combined text)
    df = df.drop_duplicates(subset=["FullText"])

    # Downsample 'Toys & Games' to 650 rows since it makes up over 70% of the raw data
    df_non_toys = df[df["MainCategory"] != "Toys & Games"]
    df_toys = df[df["MainCategory"] == "Toys & Games"]
    df_toys = df_toys.sample(n=650, random_state=42)
    df = pd.concat([df_non_toys, df_toys])

    # Filter to only top 5 MainCategories
    df = df[df["MainCategory"].isin(df["MainCategory"].value_counts().index[:5])]

    print(
        f"Prepared dataset with {len(df)} products.\n"
        f"MainCategory counts:\n{df['MainCategory'].value_counts()}"
    )

    return df.loc[
        :,
        [
            "Product Name",
            "Description",
            "MainCategory",
            "SecondaryCategory",
            "TertiaryCategory",
            "FullText",
        ],
    ]


def save_as_parquet(df: pd.DataFrame):
    """
    Save DataFrame to parquet file.
    """
    df.to_parquet(_FILE_PATH / "data" / "amazon_products.parquet", index=False)
    print(f"Saved to {_FILE_PATH / 'data' / 'amazon_products.parquet'}")


def create_faiss_index(df: pd.DataFrame, batch_size: int = 100):
    """
    Create FAISS index from product data using Fireworks AI embeddings.

    Args:
        df: DataFrame with 'FullText' column to embed
        batch_size: Number of texts to embed in each API call

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    client = create_client()
    print(f"Generating embeddings for {len(df)} products...")
    all_embeddings = []
    texts = df["FullText"].tolist()

    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        print(
            f"Processing batch {i // batch_size + 1}/{(len(texts) + batch_size - 1) // batch_size}"
        )

        response = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)

        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)

    embeddings_array = np.array(all_embeddings, dtype=np.float32)

    dimension = embeddings_array.shape[1]
    # L2 distance on unit-normalized vectors ranks identically to cosine similarity
    index = faiss.IndexFlatL2(dimension)

    # Normalize embeddings for cosine similarity
    faiss.normalize_L2(embeddings_array)

    index.add(embeddings_array)

    print(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")

    faiss.write_index(index, str(_FILE_PATH / "data" / "faiss_index.bin"))
    np.save(_FILE_PATH / "data" / "embeddings.npy", embeddings_array)

    print(f"Saved FAISS index to {_FILE_PATH / 'data' / 'faiss_index.bin'}")
    print(f"Saved embeddings to {_FILE_PATH / 'data' / 'embeddings.npy'}")

    return index, embeddings_array


def load_faiss_index():
    """
    Load pre-computed FAISS index and embeddings from disk.

    Returns:
        Tuple of (faiss_index, embeddings_array)
    """
    index = faiss.read_index(str(_FILE_PATH / "data" / "faiss_index.bin"))
    embeddings = np.load(_FILE_PATH / "data" / "embeddings.npy")
    print(f"Loaded FAISS index with {index.ntotal} vectors")
    return index, embeddings
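

# Example (sketch, not part of the build pipeline): querying the FAISS index.
# It reuses the same Fireworks embeddings call as create_faiss_index and assumes
# `df` is the prepared DataFrame in the same row order as the index; the function
# name and signature are illustrative only.
def search_faiss_example(query: str, index, df: pd.DataFrame, k: int = 5) -> pd.DataFrame:
    client = create_client()
    response = client.embeddings.create(
        model=EMBEDDING_MODEL, input=[query.lower().strip()]
    )
    query_vec = np.array([response.data[0].embedding], dtype=np.float32)
    faiss.normalize_L2(query_vec)  # match the normalization applied at index time
    _distances, indices = index.search(query_vec, k)
    return df.iloc[indices[0]]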


def create_bm25_index(df: pd.DataFrame):
    """
    Create BM25 index from product data for lexical search.

    Args:
        df: DataFrame with 'FullText' column to index

    Returns:
        BM25 index object
    """
    print(f"Creating BM25 index for {len(df)} products...")

    corpus = df["FullText"].tolist()
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en")

    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    retriever.save(_FILE_PATH / "data" / "bm25_index")

    print(f"Saved BM25 index to {_FILE_PATH / 'data' / 'bm25_index'}")
    return retriever


def load_bm25_index():
    """
    Load pre-computed BM25 index from disk.

    Returns:
        BM25 index object
    """
    retriever = bm25s.BM25.load(_FILE_PATH / "data" / "bm25_index", load_corpus=False)
    print("Loaded BM25 index")
    return retriever
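

# Example (sketch, not part of the build pipeline): querying the BM25 index.
# Returned positions map back to rows of the prepared DataFrame, since the corpus
# was built from df["FullText"]; the function name and signature are illustrative only.
def search_bm25_example(
    query: str, retriever: bm25s.BM25, df: pd.DataFrame, k: int = 5
) -> pd.DataFrame:
    query_tokens = bm25s.tokenize(query.lower().strip(), stopwords="en")
    doc_indices, _scores = retriever.retrieve(query_tokens, k=k)
    return df.iloc[doc_indices[0]]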


if __name__ == "__main__":
    _df = load_amazon_raw_product_data()
    _df = prepare_amazon_product_data(_df)
    save_as_parquet(_df)

    create_bm25_index(_df)
    create_faiss_index(_df)