# QuoteSearch / offline_processing.py
# Builds the offline quote index: load quotes from CSV, embed them, quantize to int8, pack into a binary file.
import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
import json
import gzip
import struct
# 1. Constants
MODEL_NAME = 'nomic-ai/nomic-embed-text-v1.5'
EMBEDDING_DIM = 256
INPUT_CSV_PATH = 'data/quotes.csv'
OUTPUT_BINARY_PATH = 'data/quotes_index.bin'
BATCH_SIZE = 64  # Batch size for embedding generation; tune to available GPU memory
# 2. Load Data
def load_quotes(file_path):
"""Loads quotes from a CSV file using pandas, skipping malformed lines."""
# Define column names explicitly
column_names = ['quote', 'author', 'category']
# Read CSV with explicit separator, quote character, and skip bad lines
# header=None because we are providing names explicitly
df = pd.read_csv(
file_path,
sep=',',
quotechar='"',
header=None,
names=column_names,
on_bad_lines='skip' # Skip lines that pandas cannot parse into 3 columns
)
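    # Hypothetical example of the expected 3-column layout (illustrative, not taken from the dataset):
    #   "Be yourself; everyone else is already taken.",Oscar Wilde,inspirational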
# Filter out rows where the category contains uppercase letters
initial_rows = len(df)
df = df[df['category'].apply(lambda x: isinstance(x, str) and not any(c.isupper() for c in x))]
filtered_rows = len(df)
if initial_rows - filtered_rows > 0:
print(f"Ignored {initial_rows - filtered_rows} rows due to uppercase letters in category.")
    # Drop rows with a missing quote; a non-string quote would break the prefixing step below
    df = df.dropna(subset=['quote'])
    # Ensure author is a string for grouping (empty string for missing authors)
    df['author'] = df['author'].fillna('').astype(str)
# Group by quote and author to deduplicate entries.
# We intentionally ignore categories to reduce metadata size in the output index.
grouped = {}
for _, row in df.iterrows():
quote = row['quote']
author = row['author']
# Build a case-insensitive key for deduplication
quote_key = quote.lower().strip() if isinstance(quote, str) else ''
author_key = author.lower().strip() if isinstance(author, str) else ''
key = (quote_key, author_key)
if key not in grouped:
grouped[key] = { 'quote': quote, 'author': author }
# Build records from grouped data; do NOT include categories
records = []
for key, data in grouped.items():
orig_author = data['author'] if data['author'] != '' else None
records.append({
'quote': data['quote'],
'author': orig_author
})
    # Prepend the task prefix the Nomic embedding models expect: indexed passages use
    # "search_document: " (queries are prefixed with "search_query: " at search time)
    for r in records:
        r['quote_for_embedding'] = "search_document: " + r['quote']
return records
# 3. Generate Embeddings
def generate_embeddings(quotes, model_name, embedding_dim):
"""Generates and truncates embeddings for a list of quotes."""
model = SentenceTransformer(model_name, trust_remote_code=True)
# The model automatically uses the GPU if available
embeddings = model.encode(
[q['quote_for_embedding'] for q in quotes],
convert_to_tensor=True,
batch_size=BATCH_SIZE,
show_progress_bar=True # Display progress bar for embedding generation
)
    # Truncate to the desired dimension; nomic-embed-text-v1.5 is trained with
    # Matryoshka representation learning, so the leading dimensions remain meaningful
    truncated_embeddings = embeddings[:, :embedding_dim]
return truncated_embeddings.cpu().numpy()
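# Illustrative sketch only, not called by this offline script: the search-time side
# is assumed to embed queries with the "search_query: " prefix and truncate to the
# same dimension so query and document vectors share one embedding space.
def embed_query(model, query, embedding_dim=EMBEDDING_DIM):
    """Hypothetical query-side counterpart to generate_embeddings."""
    query_embedding = model.encode(["search_query: " + query], convert_to_tensor=True)
    return query_embedding[:, :embedding_dim].cpu().numpy()[0]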
# 4. Quantize Embeddings
def quantize_embeddings(embeddings):
"""Quantizes float32 embeddings to int8."""
    # Scale so the largest-magnitude component maps to 127
    abs_max = np.abs(embeddings).max()
    scale = 127.0 / abs_max if abs_max != 0 else 1.0  # 1.0 avoids division by zero on dequantization
    # Scale, round, and clip into the symmetric int8 range
    quantized_embeddings = np.clip(np.round(embeddings * scale), -127, 127).astype(np.int8)
return quantized_embeddings, scale
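# Illustrative only, not used by this pipeline: shows how a reader could recover
# approximate float embeddings from the int8 values and the stored scale factor.
# The helper name is an assumption, not part of the actual search-time code.
def dequantize_embeddings(quantized_embeddings, scale):
    """Approximately inverts quantize_embeddings by dividing by the scale factor."""
    return quantized_embeddings.astype(np.float32) / scale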
def main():
"""Main function to run the offline processing pipeline."""
print("Starting offline processing...")
# Load quotes
quotes = load_quotes(INPUT_CSV_PATH)
print(f"Loaded {len(quotes)} quotes.")
# Generate embeddings
print("Generating embeddings...")
float_embeddings = generate_embeddings(quotes, MODEL_NAME, EMBEDDING_DIM)
print(f"Generated float embeddings with shape: {float_embeddings.shape}")
# Quantize embeddings
print("Quantizing embeddings...")
quantized_embeddings, scale = quantize_embeddings(float_embeddings)
print(f"Quantized embeddings with shape: {quantized_embeddings.shape}")
print(f"Quantization scale factor: {scale}")
# Prepare metadata without categories to reduce index size
metadata = [
{"quote": q["quote"], "author": q["author"]}
for q in quotes
]
# Replace NaN values with None for JSON compatibility
for item in metadata:
for key, value in list(item.items()):
if isinstance(value, float) and np.isnan(value):
item[key] = None
# After cleaning metadata, serialize and compress once
metadata_json = json.dumps(metadata, separators=(",", ":"))
metadata_bytes_uncompressed = metadata_json.encode('utf-8')
# Compress metadata with gzip to reduce index size on disk
metadata_bytes = gzip.compress(metadata_bytes_uncompressed)
# metadata format: 0 = uncompressed JSON (legacy), 1 = gzip-compressed JSON
metadata_format = 1
# Pack data into a binary file
print("Packaging data into binary file...")
with open(OUTPUT_BINARY_PATH, 'wb') as f:
# Header
f.write(struct.pack('<I', len(quotes))) # 4 bytes: number of quotes
f.write(struct.pack('<H', EMBEDDING_DIM)) # 2 bytes: embedding dimension
f.write(struct.pack('<f', scale)) # 4 bytes: quantization scale factor
f.write(struct.pack('<I', len(metadata_bytes))) # 4 bytes: metadata size
f.write(struct.pack('<B', metadata_format)) # 1 byte: metadata format flag
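        # Header total: 4 + 2 + 4 + 4 + 1 = 15 bytes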
# Metadata (possibly compressed)
f.write(metadata_bytes)
# Embeddings
f.write(quantized_embeddings.tobytes())
print(f"Offline processing complete. Index file saved to {OUTPUT_BINARY_PATH}")
if __name__ == "__main__":
main()