davidtran999 committed
Commit c696533 · verified · 1 Parent(s): 2fafc21

Upload backend/core/search_ml.py with huggingface_hub

Files changed (1)
  1. backend/core/search_ml.py +284 -0
backend/core/search_ml.py ADDED
@@ -0,0 +1,284 @@
"""
Machine Learning-based search utilities using TF-IDF and text similarity.
"""
import re
from typing import Any, List, Tuple, Union

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from django.db import connection
from django.db.models import Q, QuerySet, F
from django.contrib.postgres.search import SearchQuery, SearchRank

from .models import Synonym


def normalize_text(text: str) -> str:
    """Normalize Vietnamese text for search."""
    if not text:
        return ""
    # Lowercase, trim, and collapse repeated whitespace
    text = text.lower().strip()
    text = re.sub(r'\s+', ' ', text)
    return text


def expand_query_with_synonyms(query: str) -> List[str]:
    """Expand the query using synonyms from the database."""
    query_normalized = normalize_text(query)
    expanded = [query_normalized]

    try:
        synonyms = Synonym.objects.all()
        for synonym in synonyms:
            keyword = normalize_text(synonym.keyword)
            alias = normalize_text(synonym.alias)

            # If the query contains the keyword, add an alias variant
            if keyword in query_normalized:
                expanded.append(query_normalized.replace(keyword, alias))
            # If the query contains the alias, add a keyword variant
            if alias in query_normalized:
                expanded.append(query_normalized.replace(alias, keyword))
    except Exception:
        pass  # Synonym table may not exist yet

    return list(set(expanded))  # Remove duplicates


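# Illustrative behavior (hypothetical Synonym row, not part of this commit):
# given keyword="cmnd" and alias="can cuoc cong dan",
#     expand_query_with_synonyms("lam lai cmnd")
# returns, in no guaranteed order, the normalized query plus the
# alias-substituted variant:
#     ["lam lai cmnd", "lam lai can cuoc cong dan"]

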
def create_search_vector(text_fields: List[str]) -> str:
    """Create a searchable text string from multiple field values."""
    return " ".join(str(field) for field in text_fields if field)


def calculate_similarity_scores(
    query: str,
    documents: List[str],
    top_k: int = 20
) -> List[Tuple[int, float]]:
    """
    Calculate cosine similarity scores between the query and documents.
    Returns a list of (index, score) tuples sorted by score, descending.
    """
    if not query or not documents:
        return []

    # Expand the query with synonyms
    expanded_queries = expand_query_with_synonyms(query)

    # Combine all query variations with the documents
    all_texts = expanded_queries + documents

    try:
        # Create a TF-IDF vectorizer over unigrams and bigrams
        vectorizer = TfidfVectorizer(
            analyzer='word',
            ngram_range=(1, 2),
            min_df=1,
            max_df=0.95,
            lowercase=True,
            token_pattern=r'\b\w+\b'
        )

        # Fit and transform the combined corpus
        tfidf_matrix = vectorizer.fit_transform(all_texts)

        # Query vector: the mean of the expanded query vectors
        query_vectors = tfidf_matrix[:len(expanded_queries)]
        query_vector = np.mean(query_vectors.toarray(), axis=0).reshape(1, -1)

        # Document vectors follow the query rows in the matrix
        doc_vectors = tfidf_matrix[len(expanded_queries):]

        # Cosine similarity of the query against every document
        similarities = cosine_similarity(query_vector, doc_vectors)[0]

        # Keep the top k results with non-zero scores
        top_indices = np.argsort(similarities)[::-1][:top_k]
        return [
            (int(idx), float(similarities[idx]))
            for idx in top_indices
            if similarities[idx] > 0.0
        ]
    except Exception:
        # Fall back to simple substring matching if vectorization fails
        query_lower = normalize_text(query)
        results = []
        for idx, doc in enumerate(documents):
            doc_lower = normalize_text(doc)
            if query_lower in doc_lower:
                # Score by match position: earlier matches score higher
                score = 1.0 - (doc_lower.find(query_lower) / max(len(doc_lower), 1))
                results.append((idx, score))
        return sorted(results, key=lambda x: x[1], reverse=True)[:top_k]


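# Illustrative call (hypothetical data): rank three documents against a query;
# the best match comes first and zero-score documents are dropped.
#     calculate_similarity_scores(
#         "doi giay phep lai xe",
#         ["Thu tuc doi giay phep lai xe", "Dang ky xe may", "Cap ho chieu"],
#         top_k=2,
#     )
# Exact scores depend on the fitted TF-IDF vocabulary.

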
def search_with_ml(
    queryset: QuerySet,
    query: str,
    text_fields: List[str],
    top_k: int = 20,
    min_score: float = 0.1,
    use_hybrid: bool = True
) -> Union[QuerySet, List[Any]]:
    """
    Search a queryset using ML-based similarity scoring.

    Args:
        queryset: Django QuerySet to search
        query: Search query string
        text_fields: List of field names to search in
        top_k: Maximum number of results
        min_score: Minimum similarity score threshold
        use_hybrid: Whether to try hybrid (BM25 + vector) search first

    Returns:
        A filtered and ranked QuerySet, or a ranked list of model
        instances when the ranking cannot be expressed as a QuerySet.
    """
    if not query:
        return queryset[:top_k]

    # Try hybrid search if enabled
    if use_hybrid:
        try:
            from .hybrid_search import search_with_hybrid
            from .config.hybrid_search_config import get_config

            # Determine the content type from the model name
            model_name = queryset.model.__name__.lower()
            content_type = None
            if 'procedure' in model_name:
                content_type = 'procedure'
            elif 'fine' in model_name:
                content_type = 'fine'
            elif 'office' in model_name:
                content_type = 'office'
            elif 'advisory' in model_name:
                content_type = 'advisory'
            elif 'legalsection' in model_name:
                content_type = 'legal'

            config = get_config(content_type)
            return search_with_hybrid(
                queryset,
                query,
                text_fields,
                top_k=top_k,
                min_score=min_score,
                use_hybrid=True,
                bm25_weight=config.bm25_weight,
                vector_weight=config.vector_weight
            )
        except Exception as e:
            print(f"Hybrid search not available, using BM25/TF-IDF: {e}")

    # Attempt PostgreSQL full-text ranking first when available
    if connection.vendor == "postgresql" and hasattr(queryset.model, "tsv_body"):
        try:
            expanded_queries = expand_query_with_synonyms(query)
            combined_query = None
            for q_variant in expanded_queries:
                variant_query = SearchQuery(q_variant, config="simple")
                combined_query = variant_query if combined_query is None else combined_query | variant_query

            if combined_query is not None:
                ranked_qs = (
                    queryset
                    .annotate(rank=SearchRank(F("tsv_body"), combined_query))
                    .filter(rank__gt=0)
                    .order_by("-rank")
                )
                results = list(ranked_qs[:top_k])
                if results:
                    for obj in results:
                        obj._ml_score = getattr(obj, "rank", 0.0)
                    return results
        except Exception:
            # Fall through to ML-based search on any error (e.g. missing extensions)
            pass

    # Materialize the queryset and build one search document per object
    all_objects = list(queryset)
    if not all_objects:
        return queryset.none()

    documents = []
    for obj in all_objects:
        field_values = [getattr(obj, field, "") for field in text_fields]
        documents.append(create_search_vector(field_values))

    # Calculate similarity scores
    try:
        scored_indices = calculate_similarity_scores(query, documents, top_k=top_k)

        # Filter by the minimum score threshold
        valid_indices = [idx for idx, score in scored_indices if score >= min_score]

        # If ML search found results, use them
        if valid_indices:
            result_objects = [all_objects[idx] for idx in valid_indices]
            result_ids = [obj.id for obj in result_objects]

            if result_ids:
                # Map each ID to its rank so results can be re-sorted
                id_to_order = {obj_id: idx for idx, obj_id in enumerate(result_ids)}

                # Filter by IDs, then sort back into rank order
                result_list = list(queryset.filter(id__in=result_ids))
                result_list.sort(key=lambda x: id_to_order.get(x.id, 999))

                # Return a queryset ordered by rank via a Case/When expression
                ordered_ids = [obj.id for obj in result_list[:top_k]]
                if ordered_ids:
                    from django.db.models import Case, When, IntegerField
                    preserved = Case(
                        *[When(pk=pk, then=pos) for pos, pk in enumerate(ordered_ids)],
                        output_field=IntegerField()
                    )
                    return queryset.filter(id__in=ordered_ids).order_by(preserved)
    except Exception:
        # If ML search fails, fall back to simple search below
        pass

    # Fallback: simple icontains search with exact-phrase prioritization
    query_lower = normalize_text(query)
    query_words = query_lower.split()

    # Extract 2- and 3-word key phrases for better matching
    key_phrases = []
    for i in range(len(query_words) - 1):
        phrase = " ".join(query_words[i:i + 2])
        if len(phrase) > 3:
            key_phrases.append(phrase)
    for i in range(len(query_words) - 2):
        phrase = " ".join(query_words[i:i + 3])
        if len(phrase) > 5:
            key_phrases.append(phrase)

    # Try exact phrase matches on the primary field first
    exact_matches = []
    primary_field = text_fields[0] if text_fields else None
    if primary_field:
        for phrase in key_phrases:
            filter_kwargs = {f"{primary_field}__icontains": phrase}
            exact_matches.extend(queryset.filter(**filter_kwargs)[:top_k])

    # If exact phrase matches were found, prioritize them
    if exact_matches:
        # Remove duplicates while preserving order
        seen = set()
        unique_matches = []
        for obj in exact_matches:
            if obj.id not in seen:
                seen.add(obj.id)
                unique_matches.append(obj)
        return unique_matches[:top_k]

    # Last resort: OR all fields together with icontains
    q_objects = Q()
    for field in text_fields:
        q_objects |= Q(**{f"{field}__icontains": query})
    return queryset.filter(q_objects)[:top_k]
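
A minimal usage sketch (illustrative only): the model name "Procedure", the field names "title"/"content", and the "core" import path below are assumed placeholders, not confirmed by this commit.

    from core.models import Procedure  # hypothetical model
    from core.search_ml import search_with_ml

    results = search_with_ml(
        Procedure.objects.all(),
        query="doi giay phep lai xe",
        text_fields=["title", "content"],  # hypothetical field names
        top_k=10,
        use_hybrid=False,  # exercise the Postgres/TF-IDF paths directly
    )
    # Depending on the path taken, results is a ranked QuerySet or a plain list
    for obj in results:
        print(obj.id, getattr(obj, "_ml_score", None))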