davidtran999 committed on
Commit
fadb000
·
verified ·
1 Parent(s): 0391f27

Upload backend/core/faiss_index.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. backend/core/faiss_index.py +242 -0
backend/core/faiss_index.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FAISS index management for fast vector similarity search.
3
+ """
4
+ import os
5
+ import pickle
6
+ from pathlib import Path
7
+ from typing import List, Optional, Tuple
8
+ import numpy as np
9
+
10
+ try:
11
+ import faiss
12
+ FAISS_AVAILABLE = True
13
+ except ImportError:
14
+ FAISS_AVAILABLE = False
15
+ faiss = None
16
+
17
+ from django.conf import settings
18
+
19
+
20
# Directory where serialized FAISS indexes (and their mapping files) live:
# <BASE_DIR>/artifacts/faiss_indexes. It is created at import time so that
# later save calls can assume it exists.
INDEX_DIR = Path(settings.BASE_DIR).joinpath("artifacts", "faiss_indexes")
INDEX_DIR.mkdir(exist_ok=True, parents=True)
23
+
24
+
25
class FAISSIndex:
    """FAISS index wrapper for vector similarity search.

    Vectors are L2-normalized before they are trained on, added, or used as
    a query. The wrapper keeps two dictionaries that translate between the
    caller's object IDs and the sequential positions FAISS assigns to added
    vectors.
    """

    # Number of clusters used when building an IVF index.
    IVF_NLIST = 100
    # Number of neighbor connections per node used when building an HNSW index.
    HNSW_M = 32

    def __init__(self, dimension: int, index_type: str = "IVF"):
        """
        Initialize FAISS index.

        Args:
            dimension: Embedding dimension.
            index_type: Type of index ('IVF', 'HNSW', 'Flat').

        Raises:
            ImportError: If the faiss package could not be imported.
            ValueError: If index_type is not one of the supported types.
        """
        if not FAISS_AVAILABLE:
            raise ImportError("FAISS not available. Install with: pip install faiss-cpu")

        self.dimension = dimension
        self.index_type = index_type
        self.index = None
        self.id_to_index = {}  # Map object ID to FAISS position
        self.index_to_id = {}  # Reverse mapping: FAISS position to object ID
        self._build_index()

    def _build_index(self):
        """Build the underlying FAISS index according to self.index_type."""
        if self.index_type == "Flat":
            # Brute-force exact search
            self.index = faiss.IndexFlatL2(self.dimension)
        elif self.index_type == "IVF":
            # Inverted file index (approximate, faster); must be trained
            # before vectors can be added.
            quantizer = faiss.IndexFlatL2(self.dimension)
            self.index = faiss.IndexIVFFlat(quantizer, self.dimension, self.IVF_NLIST)
        elif self.index_type == "HNSW":
            # Hierarchical Navigable Small World (fast approximate)
            self.index = faiss.IndexHNSWFlat(self.dimension, self.HNSW_M)
        else:
            raise ValueError(f"Unknown index type: {self.index_type}")

    def _as_normalized_rows(self, vectors) -> np.ndarray:
        """Return an L2-normalized float32 copy of vectors, shaped (n, dimension).

        A fresh C-contiguous array is always allocated because
        faiss.normalize_L2 works in place; normalizing the copy leaves the
        caller's data untouched.

        Raises:
            ValueError: If the input cannot be viewed as rows of length
                self.dimension.
        """
        rows = np.array(vectors, dtype='float32', order='C')
        if rows.ndim == 1:
            # A single vector is treated as one row.
            rows = rows.reshape(1, -1)
        if rows.ndim != 2 or rows.shape[1] != self.dimension:
            raise ValueError(
                f"Expected vectors of shape (n, {self.dimension}), got {rows.shape}"
            )
        faiss.normalize_L2(rows)
        return rows

    def train(self, vectors: np.ndarray):
        """Train the index on vectors if it is not trained yet (required for IVF).

        The vectors are normalized (on a copy) first, so that training sees
        the same kind of data that add() and search() use.
        """
        if not self.index.is_trained:
            self.index.train(self._as_normalized_rows(vectors))

    def add(self, vectors: np.ndarray, ids: List[int]):
        """
        Add vectors to index.

        The input array is not modified; a normalized float32 copy is what
        gets stored. If the index still needs training, it is trained on
        this batch first.

        Args:
            vectors: Numpy array of shape (n, dimension).
            ids: List of object IDs corresponding to vectors, one per row.

        Raises:
            ValueError: If the vectors have the wrong shape or the number of
                ids differs from the number of vectors.
        """
        if len(vectors) == 0:
            return

        rows = self._as_normalized_rows(vectors)
        ids = list(ids)
        if len(ids) != rows.shape[0]:
            raise ValueError(
                f"Got {rows.shape[0]} vectors but {len(ids)} ids; they must match"
            )

        # Train if needed (for IVF); rows are already normalized.
        if not self.index.is_trained:
            self.index.train(rows)

        # FAISS numbers added vectors sequentially, so the first new vector
        # lands at the current total. Using ntotal (rather than the size of
        # the ID map) keeps positions correct even if an ID is added twice.
        start_idx = int(self.index.ntotal)

        self.index.add(rows)

        # Update mappings. Re-adding an existing ID points id_to_index at the
        # newest position; the older position remains in index_to_id.
        for offset, obj_id in enumerate(ids):
            faiss_idx = start_idx + offset
            self.id_to_index[obj_id] = faiss_idx
            self.index_to_id[faiss_idx] = obj_id

    def search(self, query_vector: np.ndarray, k: int = 10) -> List[Tuple[int, float]]:
        """
        Search for similar vectors.

        Args:
            query_vector: Query vector of shape (dimension,).
            k: Number of results to return.

        Returns:
            List of (object_id, similarity) tuples in the order FAISS returns
            them (nearest first). similarity is 1 / (1 + d), where d is the
            distance reported by the index, so it lies in (0, 1] and equals
            1.0 for an exact match. Fewer than k results may be returned.

        Raises:
            ValueError: If the query does not have self.dimension elements.
        """
        if self.index.ntotal == 0 or k <= 0:
            return []

        # Flatten to a single row, then normalize a private float32 copy.
        query = self._as_normalized_rows(np.reshape(query_vector, (1, -1)))

        # Asking for more neighbors than stored vectors only yields padding.
        k = min(k, int(self.index.ntotal))
        distances, indices = self.index.search(query, k)

        # Convert FAISS positions back to object IDs.
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            if idx < 0:  # FAISS pads missing results with -1
                continue
            obj_id = self.index_to_id.get(int(idx))
            if obj_id is not None:
                results.append((obj_id, 1.0 / (1.0 + float(dist))))

        return results

    def save(self, filepath: Path):
        """Save the index to filepath and its mappings next to it.

        The ID mappings, dimension and index type are pickled to the same
        path with its suffix replaced by '.mappings.pkl'. Parent directories
        are created if missing.
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Save FAISS index
        faiss.write_index(self.index, str(filepath))

        # Save mappings
        mappings_file = filepath.with_suffix('.mappings.pkl')
        with open(mappings_file, 'wb') as f:
            pickle.dump({
                'id_to_index': self.id_to_index,
                'index_to_id': self.index_to_id,
                'dimension': self.dimension,
                'index_type': self.index_type
            }, f)

    @classmethod
    def load(cls, filepath: Path) -> 'FAISSIndex':
        """Load an index previously written by save().

        Args:
            filepath: Path of the FAISS index file; the mappings are read
                from the sibling '.mappings.pkl' file.

        Raises:
            ImportError: If the faiss package could not be imported.
            FileNotFoundError: If the index file or its mappings file is
                missing.
        """
        if not FAISS_AVAILABLE:
            raise ImportError("FAISS not available. Install with: pip install faiss-cpu")

        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Index file not found: {filepath}")

        mappings_file = filepath.with_suffix('.mappings.pkl')
        if not mappings_file.exists():
            raise FileNotFoundError(f"Mappings file not found: {mappings_file}")

        # Load FAISS index
        index = faiss.read_index(str(filepath))

        # Load mappings.
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only load
        # mapping files that this application wrote itself.
        with open(mappings_file, 'rb') as f:
            mappings = pickle.load(f)

        # Bypass __init__ so no fresh, empty index is built.
        instance = cls.__new__(cls)
        instance.index = index
        instance.id_to_index = mappings['id_to_index']
        instance.index_to_id = mappings['index_to_id']
        instance.dimension = mappings['dimension']
        instance.index_type = mappings['index_type']

        return instance
172
+
173
+
174
def build_faiss_index_for_model(model_class, model_name: str, index_type: str = "IVF") -> Optional[FAISSIndex]:
    """
    Build FAISS index for a Django model and save it under INDEX_DIR.

    Embeddings are loaded first, and only afterwards is the index type
    finalized, so the IVF fallback is based on the number of vectors that
    will actually be indexed rather than on the number of database rows.

    Args:
        model_class: Django model class.
        model_name: Name of model (for file naming).
        index_type: Type of FAISS index.

    Returns:
        FAISSIndex instance, or None if FAISS is unavailable, the embedding
        dimension is unknown, or no usable embeddings exist.
    """
    if not FAISS_AVAILABLE:
        print("FAISS not available. Skipping index build.")
        return None

    from hue_portal.core.embeddings import get_embedding_dimension
    from hue_portal.core.embedding_utils import load_embedding

    # Get embedding dimension
    dim = get_embedding_dimension()
    if dim == 0:
        print("Cannot determine embedding dimension. Skipping index build.")
        return None

    # Get all instances that have an embedding stored
    instances = list(model_class.objects.exclude(embedding__isnull=True))
    if not instances:
        print(f"No instances with embeddings found for {model_name}.")
        return None

    # Collect vectors and IDs, keeping only embeddings of the expected size.
    # NOTE(review): assumes load_embedding returns a numeric sequence/array
    # or None - confirm against hue_portal.core.embedding_utils.
    vectors = []
    ids = []
    skipped = 0

    for instance in instances:
        embedding = load_embedding(instance)
        if embedding is None:
            continue
        embedding = np.asarray(embedding, dtype='float32').reshape(-1)
        if embedding.shape[0] != dim:
            # A single malformed embedding must not abort the whole build.
            skipped += 1
            continue
        vectors.append(embedding)
        ids.append(instance.id)

    if skipped:
        print(f"⚠️ Skipped {skipped} embeddings for {model_name} whose size is not {dim}.")

    if not vectors:
        print(f"No valid embeddings found for {model_name}.")
        return None

    # Auto-adjust index type: IVF is built with 100 clusters (see
    # FAISSIndex._build_index) and needs at least that many training
    # vectors. With fewer usable vectors, use a Flat index instead.
    min_ivf_vectors = 100
    if index_type == "IVF" and len(vectors) < min_ivf_vectors:
        print(f"⚠️ Only {len(vectors)} usable embeddings found. Switching from IVF to Flat index (IVF requires >= 100 vectors).")
        index_type = "Flat"

    # Create index
    faiss_index = FAISSIndex(dimension=dim, index_type=index_type)

    print(f"Building FAISS index for {model_name} ({len(vectors)} instances, type: {index_type})...")

    # Convert to a single (n, dim) array and add to index
    vectors_array = np.array(vectors, dtype='float32')
    faiss_index.add(vectors_array, ids)

    # Save index
    index_file = INDEX_DIR / f"{model_name.lower()}_{index_type.lower()}.faiss"
    faiss_index.save(index_file)

    print(f"✅ Built and saved FAISS index: {index_file}")
    return faiss_index
242
+