davidtran999 committed
Commit b94f6bc · verified · 1 parent: 05069e2

Upload backend/hue_portal/core/search_ml.py with huggingface_hub

Files changed (1)
  1. backend/hue_portal/core/search_ml.py +336 -0
backend/hue_portal/core/search_ml.py ADDED
@@ -0,0 +1,336 @@
"""
Machine Learning-based search utilities using TF-IDF and text similarity.
"""
import re
from typing import List, Tuple

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from django.db import connection
from django.db.models import Q, QuerySet, F
from django.contrib.postgres.search import SearchQuery, SearchRank

from .models import Synonym


def normalize_text(text: str) -> str:
    """Normalize Vietnamese text for search."""
    if not text:
        return ""
    # Lowercase, trim, and collapse runs of whitespace
    text = text.lower().strip()
    text = re.sub(r'\s+', ' ', text)
    return text
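
# Example (illustrative, not part of the committed file):
#   normalize_text("  Căn   Cước  ")  ->  "căn cước"
# (lowercased, trimmed, inner whitespace collapsed)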


def expand_query_with_synonyms(query: str) -> List[str]:
    """Expand query using synonyms from database."""
    query_normalized = normalize_text(query)
    expanded = [query_normalized]

    try:
        # Cap the number of expansions to keep the variant list small
        max_expansions = 10
        expansion_count = 0

        # Cap the number of synonym rows considered per query
        synonyms = Synonym.objects.all()[:100]
        for synonym in synonyms:
            if expansion_count >= max_expansions:
                break

            keyword = normalize_text(synonym.keyword)
            alias = normalize_text(synonym.alias)

            # If the query contains the keyword, add the alias variant
            if keyword and keyword in query_normalized:
                new_query = query_normalized.replace(keyword, alias)
                if new_query not in expanded:
                    expanded.append(new_query)
                    expansion_count += 1
                    if expansion_count >= max_expansions:
                        break

            # If the query contains the alias, add the keyword variant
            if alias and alias in query_normalized:
                new_query = query_normalized.replace(alias, keyword)
                if new_query not in expanded:
                    expanded.append(new_query)
                    expansion_count += 1
                    if expansion_count >= max_expansions:
                        break
    except Exception:
        pass  # Synonym table may not exist yet (e.g. before migrations)

    return list(set(expanded))[:10]  # De-duplicate and cap at 10 variants
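
# Example (illustrative; assumes a hypothetical row
# Synonym(keyword="cmnd", alias="chứng minh nhân dân") exists):
#   expand_query_with_synonyms("làm cmnd")
#   -> ["làm cmnd", "làm chứng minh nhân dân"]  (order may vary after set())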


def create_search_vector(text_fields: List[str]) -> str:
    """Create a searchable text blob by joining the non-empty fields."""
    return " ".join(str(field) for field in text_fields if field)
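
# Example (illustrative): falsy fields (None, "") are skipped, others stringified:
#   create_search_vector(["Mức phạt", None, 150000])  ->  "Mức phạt 150000"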


def calculate_similarity_scores(
    query: str,
    documents: List[str],
    top_k: int = 20
) -> List[Tuple[int, float]]:
    """
    Calculate cosine similarity scores between query and documents.
    Returns list of (index, score) tuples sorted by score descending.
    """
    if not query or not documents:
        return []

    # Expand query with synonyms
    expanded_queries = expand_query_with_synonyms(query)

    # Vectorize query variants and documents in one shared vocabulary
    all_texts = expanded_queries + documents

    try:
        # Create TF-IDF vectorizer
        vectorizer = TfidfVectorizer(
            analyzer='word',
            ngram_range=(1, 2),  # Unigrams and bigrams
            min_df=1,
            max_df=0.95,
            lowercase=True,
            token_pattern=r'\b\w+\b'
        )

        # Fit and transform
        tfidf_matrix = vectorizer.fit_transform(all_texts)

        # Query vector = mean of the expanded-query vectors
        query_vectors = tfidf_matrix[:len(expanded_queries)]
        query_vector = np.mean(query_vectors.toarray(), axis=0).reshape(1, -1)

        # Document vectors are the remaining rows
        doc_vectors = tfidf_matrix[len(expanded_queries):]

        # Cosine similarity of the query vector against every document
        similarities = cosine_similarity(query_vector, doc_vectors)[0]

        # Top-k indices with strictly positive scores
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (int(idx), float(similarities[idx]))
            for idx in top_indices
            if similarities[idx] > 0.0
        ]
    except Exception:
        # Fallback to simple substring matching if vectorization fails
        query_lower = normalize_text(query)
        results = []
        for idx, doc in enumerate(documents):
            doc_lower = normalize_text(doc)
            if query_lower in doc_lower:
                # Score by match position: earlier matches score higher
                score = 1.0 - (doc_lower.find(query_lower) / max(len(doc_lower), 1))
                results.append((idx, score))
        return sorted(results, key=lambda x: x[1], reverse=True)[:top_k]
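
# Example (illustrative; assumes no synonym expansion applies):
#   calculate_similarity_scores("đăng ký xe",
#                               ["thủ tục đăng ký xe máy", "cấp hộ chiếu"])
# ranks index 0 first with a positive score; index 1 shares no terms with the
# query, so it is dropped by the score > 0.0 filter.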


def search_with_ml(
    queryset: QuerySet,
    query: str,
    text_fields: List[str],
    top_k: int = 20,
    min_score: float = 0.1,
    use_hybrid: bool = True
) -> QuerySet:
    """
    Search queryset using ML-based similarity scoring.

    Args:
        queryset: Django QuerySet to search
        query: Search query string
        text_fields: List of field names to search in
        top_k: Maximum number of results
        min_score: Minimum similarity score threshold
        use_hybrid: Try hybrid (BM25 + vector) search before falling back

    Returns:
        Filtered and ranked results. Note: some code paths return a list of
        model instances rather than a QuerySet.
    """
    if not query:
        return queryset[:top_k]

    # Try hybrid search if enabled
    if use_hybrid:
        try:
            from .hybrid_search import search_with_hybrid
            from .config.hybrid_search_config import get_config

            # Derive a content type from the model name
            model_name = queryset.model.__name__.lower()
            content_type = None
            if 'procedure' in model_name:
                content_type = 'procedure'
            elif 'fine' in model_name:
                content_type = 'fine'
            elif 'office' in model_name:
                content_type = 'office'
            elif 'advisory' in model_name:
                content_type = 'advisory'
            elif 'legalsection' in model_name:
                content_type = 'legal'

            config = get_config(content_type)
            return search_with_hybrid(
                queryset,
                query,
                text_fields,
                top_k=top_k,
                min_score=min_score,
                use_hybrid=True,
                bm25_weight=config.bm25_weight,
                vector_weight=config.vector_weight
            )
        except Exception as e:
            print(f"Hybrid search not available, using full-text/TF-IDF: {e}")

    # Try PostgreSQL full-text ranking (SearchRank/ts_rank) when available
    if connection.vendor == "postgresql" and hasattr(queryset.model, "tsv_body"):
        try:
            import sys
            # Temporarily raise the recursion limit for query expansion
            old_limit = sys.getrecursionlimit()
            try:
                sys.setrecursionlimit(3000)
                expanded_queries = expand_query_with_synonyms(query)
                # Keep at most 5 variants to bound query size
                expanded_queries = expanded_queries[:5]

                # OR the variants together into a single search query
                combined_query = None
                for q_variant in expanded_queries:
                    variant_query = SearchQuery(q_variant, config="simple")
                    combined_query = variant_query if combined_query is None else combined_query | variant_query

                if combined_query is not None:
                    ranked_qs = (
                        queryset
                        .annotate(rank=SearchRank(F("tsv_body"), combined_query))
                        .filter(rank__gt=0)
                        .order_by("-rank")
                    )
                    results = list(ranked_qs[:top_k])
                    if results:
                        for obj in results:
                            obj._ml_score = getattr(obj, "rank", 0.0)
                        return results
            finally:
                sys.setrecursionlimit(old_limit)  # Restore the original limit
        except RecursionError:
            # Fallback: use the original query without expansion
            try:
                variant_query = SearchQuery(query, config="simple")
                ranked_qs = (
                    queryset
                    .annotate(rank=SearchRank(F("tsv_body"), variant_query))
                    .filter(rank__gt=0)
                    .order_by("-rank")
                )
                results = list(ranked_qs[:top_k])
                if results:
                    for obj in results:
                        obj._ml_score = getattr(obj, "rank", 0.0)
                    return results
            except Exception:
                pass
        except Exception:
            # Fall through to TF-IDF search on any other error
            # (e.g. missing extensions)
            pass

    # Materialize the queryset and build one searchable document per object
    all_objects = list(queryset)
    if not all_objects:
        return queryset.none()

    documents = []
    for obj in all_objects:
        field_values = [getattr(obj, field, "") for field in text_fields]
        documents.append(create_search_vector(field_values))

    # Calculate similarity scores
    try:
        import sys
        # Temporarily raise the recursion limit for TF-IDF calculation
        old_limit = sys.getrecursionlimit()
        try:
            sys.setrecursionlimit(3000)
            scored_indices = calculate_similarity_scores(query, documents, top_k=top_k)
        finally:
            sys.setrecursionlimit(old_limit)  # Restore the original limit

        # Keep only documents at or above the minimum score
        valid_indices = [idx for idx, score in scored_indices if score >= min_score]

        # If the TF-IDF search found results, use them
        if valid_indices:
            result_objects = [all_objects[idx] for idx in valid_indices]
            result_ids = [obj.id for obj in result_objects]

            if result_ids:
                # Map each ID to its rank so we can re-sort the DB rows
                id_to_order = {obj_id: idx for idx, obj_id in enumerate(result_ids)}

                # Fetch the matching rows and restore the ranked order
                result_list = list(queryset.filter(id__in=result_ids))
                result_list.sort(key=lambda x: id_to_order.get(x.id, 999))

                # Rebuild a queryset that preserves the ranked order
                ordered_ids = [obj.id for obj in result_list[:top_k]]
                if ordered_ids:
                    from django.db.models import Case, When, IntegerField
                    preserved = Case(
                        *[When(pk=pk, then=pos) for pos, pk in enumerate(ordered_ids)],
                        output_field=IntegerField()
                    )
                    return queryset.filter(id__in=ordered_ids).order_by(preserved)
    except Exception:
        # If the TF-IDF search fails, fall back to simple search below
        pass
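
    # Note (illustrative): the Case/When annotation above is a standard Django
    # trick for preserving an in-Python ranking in SQL; for ordered_ids = [7, 3]
    # it orders by CASE WHEN pk=7 THEN 0 WHEN pk=3 THEN 1 END.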

    # Fallback: simple icontains search, with exact phrase matches prioritized
    query_lower = normalize_text(query)
    query_words = query_lower.split()

    # Extract key phrases (2-3 words) for better matching
    key_phrases = []
    for i in range(len(query_words) - 1):
        phrase = " ".join(query_words[i:i+2])
        if len(phrase) > 3:
            key_phrases.append(phrase)
    for i in range(len(query_words) - 2):
        phrase = " ".join(query_words[i:i+3])
        if len(phrase) > 5:
            key_phrases.append(phrase)

    # Try exact phrase matches against the primary field first
    exact_matches = []
    primary_field = text_fields[0] if text_fields else None
    if primary_field:
        for phrase in key_phrases:
            filter_kwargs = {f"{primary_field}__icontains": phrase}
            exact_matches.extend(queryset.filter(**filter_kwargs)[:top_k])

    # If we found exact phrase matches, prioritize them
    if exact_matches:
        # Remove duplicates while preserving order
        seen = set()
        unique_matches = []
        for obj in exact_matches:
            if obj.id not in seen:
                seen.add(obj.id)
                unique_matches.append(obj)
        return unique_matches[:top_k]

    # Last resort: OR together icontains filters over every text field
    q_objects = Q()
    for field in text_fields:
        q_objects |= Q(**{f"{field}__icontains": query})
    return queryset.filter(q_objects)[:top_k]
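
A minimal usage sketch (illustrative, not part of the commit): "Procedure" and
the field names below are assumptions standing in for whatever models this
project searches; only search_with_ml and its signature come from the file above.

    # Hypothetical caller, e.g. in a view or service function
    from hue_portal.core.search_ml import search_with_ml
    from hue_portal.core.models import Procedure  # hypothetical model name

    results = search_with_ml(
        Procedure.objects.all(),
        query="làm căn cước công dân",
        text_fields=["title", "summary"],  # assumed field names
        top_k=10,
        min_score=0.1,
    )
    # Depending on which path matched, results is a QuerySet or a plain list;
    # PostgreSQL full-text hits also carry an _ml_score attribute.
    for obj in results:
        print(obj.id, getattr(obj, "_ml_score", None))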