davidtran999 committed on
Commit
b930e6c
·
verified ·
1 Parent(s): 8883a13

Upload backend/hue_portal/core/tests/test_query_rewriter.py with huggingface_hub

Browse files
backend/hue_portal/core/tests/test_query_rewriter.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for Query Rewriter.
3
+ """
4
+ import unittest
5
+ from unittest.mock import Mock, patch
6
+ from hue_portal.core.query_rewriter import QueryRewriter, get_query_rewriter
7
+
8
+
class TestQueryRewriter(unittest.TestCase):
    """Unit tests for ``QueryRewriter`` and ``get_query_rewriter``.

    The LLM generator handed to the rewriter is a ``Mock`` configured in
    ``setUp``; individual tests flip ``is_available`` to exercise the
    non-LLM path.
    """

    def setUp(self):
        """Create a mocked LLM generator and a rewriter that uses it."""
        rewrites = ["nội dung điều 12", "quy định điều 12", "điều 12 quy định về"]

        self.llm_generator = Mock()
        self.llm_generator.is_available.return_value = True
        # Raw text returned by the mocked generation call.
        self.llm_generator._generate_from_prompt.return_value = '{"queries": ["nội dung điều 12", "quy định điều 12", "điều 12 quy định về"]}'
        # Parsed payload returned by the mocked JSON-extraction call.
        self.llm_generator._extract_json_payload.return_value = {
            "queries": rewrites,
        }
        self.rewriter = QueryRewriter(llm_generator=self.llm_generator)

    def test_rewrite_query_with_llm(self):
        """An available LLM yields 3-5 string queries from a single LLM call."""
        queries = self.rewriter.rewrite_query("điều 12 nói gì")

        self.assertIsInstance(queries, list)
        self.assertGreaterEqual(len(queries), 3)
        self.assertLessEqual(len(queries), 5)
        self.assertTrue(all(isinstance(q, str) for q in queries))

        # The rewriter must have asked the LLM exactly once.
        self.llm_generator._generate_from_prompt.assert_called_once()

    def test_rewrite_query_fallback(self):
        """With the LLM unavailable, 3-5 queries including the original are returned."""
        self.llm_generator.is_available.return_value = False
        rewriter = QueryRewriter(llm_generator=self.llm_generator)

        queries = rewriter.rewrite_query("điều 12 nói gì")

        self.assertIsInstance(queries, list)
        self.assertGreaterEqual(len(queries), 3)
        self.assertLessEqual(len(queries), 5)
        # The original query must be among the results.
        self.assertIn("điều 12 nói gì", queries)

    def test_rewrite_query_empty(self):
        """Empty and blank queries produce an empty result list."""
        # subTest reports each input separately, so a failure on the first
        # input does not hide the outcome of the second.
        for blank_query in ("", " "):
            with self.subTest(query=blank_query):
                self.assertEqual(self.rewriter.rewrite_query(blank_query), [])

    def test_rewrite_query_with_context(self):
        """Rewriting with conversation context still sends the query to the LLM."""
        context = [
            {"role": "user", "content": "Tôi muốn hỏi về kỷ luật"},
            {"role": "bot", "content": "Bạn muốn hỏi về vấn đề gì?"},
        ]

        queries = self.rewriter.rewrite_query("điều 12", context=context)

        self.assertIsInstance(queries, list)
        self.assertGreaterEqual(len(queries), 3)

        # Fail with a clear message (instead of a TypeError on ``None``) when
        # the LLM was never invoked.
        self.llm_generator._generate_from_prompt.assert_called()
        # Only the presence of the query text in the prompt is checked here;
        # the context contents themselves are not asserted.
        prompt = self.llm_generator._generate_from_prompt.call_args[0][0]
        self.assertIn("điều 12", prompt)

    def test_get_cache_key(self):
        """Cache keys are equal for equal queries and differ for different ones."""
        key1 = self.rewriter.get_cache_key("điều 12 nói gì")
        key2 = self.rewriter.get_cache_key("điều 12 nói gì")
        key3 = self.rewriter.get_cache_key("điều 13 nói gì")

        # Same query -> same key.
        self.assertEqual(key1, key2)
        # Different query -> different key.
        self.assertNotEqual(key1, key3)

    def test_get_cache_key_with_context(self):
        """Cache keys take the conversation context into account."""
        context = [{"role": "user", "content": "test"}]
        key1 = self.rewriter.get_cache_key("điều 12", context=context)
        key2 = self.rewriter.get_cache_key("điều 12", context=context)
        key3 = self.rewriter.get_cache_key("điều 12", context=None)

        # Same query + same context -> same key.
        self.assertEqual(key1, key2)
        # Same query, context vs. no context -> different key.
        self.assertNotEqual(key1, key3)

    def test_fallback_patterns(self):
        """With the LLM unavailable, known query shapes yield more than one query."""
        self.llm_generator.is_available.return_value = False
        rewriter = QueryRewriter(llm_generator=self.llm_generator)

        # Each pattern is an independent case; subTest keeps their results apart.
        with self.subTest(pattern="điều"):
            queries = rewriter.rewrite_query("điều 12")
            self.assertGreater(len(queries), 1)

        with self.subTest(pattern="phạt"):
            queries = rewriter.rewrite_query("mức phạt vi phạm")
            self.assertGreater(len(queries), 1)
            self.assertTrue(any("phạt" in q.lower() for q in queries))

    def test_get_query_rewriter(self):
        """``get_query_rewriter`` returns a ``QueryRewriter`` with or without an LLM argument."""
        rewriter = get_query_rewriter()
        self.assertIsInstance(rewriter, QueryRewriter)

        rewriter2 = get_query_rewriter(self.llm_generator)
        self.assertIsInstance(rewriter2, QueryRewriter)
114
+
115
+
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
118
+