File size: 12,910 Bytes
6510698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
Integration tests for full application workflow.

These tests verify that all components work together correctly
in both CPU and GPU environments.
"""

import pytest
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch, MagicMock

# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))

from app.model_loader import ModelLoader
from app.interface import GradioInterface
from app.config.model_config import EnvironmentDetector, DependencyValidator


class TestCPUWorkflow:
    """Test complete workflow in CPU environment."""

    @patch('torch.cuda.is_available')
    @patch('app.config.model_config.DependencyValidator.is_environment_ready')
    def test_cpu_environment_detection(self, mock_env_ready, mock_cuda):
        """Test CPU environment is detected correctly."""
        mock_cuda.return_value = False
        mock_env_ready.return_value = True

        config = EnvironmentDetector.create_model_config()

        # CPU path: full precision, eager attention (no GPU-only kernels),
        # and low_cpu_mem_usage to limit peak RAM during loading.
        assert config.device_map == "cpu"
        assert config.dtype.name == "float32"  # torch.float32
        assert config.attn_implementation == "eager"
        assert config.low_cpu_mem_usage is True

    @patch('torch.cuda.is_available')
    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    @patch('transformers.AutoTokenizer.from_pretrained')
    @patch('transformers.AutoModelForCausalLM.from_pretrained')
    @patch('transformers.pipeline')
    def test_cpu_model_loading_workflow(self, mock_pipeline, mock_model,
                                       mock_tokenizer, mock_env_ready, mock_cuda):
        """Test complete model loading workflow on CPU."""
        # Setup mocks: healthy environment, no CUDA available.
        mock_cuda.return_value = False
        mock_env_ready.return_value = True

        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.return_value = mock_tokenizer_instance

        # Model stub: eval() must return the model itself (fluent call).
        mock_model_instance = MagicMock()
        mock_model_instance.eval.return_value = mock_model_instance
        mock_model_instance.num_parameters.return_value = 41900000000  # 41.9B
        mock_model.return_value = mock_model_instance

        mock_pipeline_instance = MagicMock()
        mock_pipeline.return_value = mock_pipeline_instance

        # Mock successful pipeline smoke test
        mock_pipeline_instance.return_value = [{"generated_text": "Hello world"}]

        # Test the workflow
        loader = ModelLoader()
        success = loader.load_complete_model()

        assert success is True
        assert loader.is_loaded is True

        # Verify CPU-specific parameters were passed to from_pretrained
        model_call_args = mock_model.call_args
        assert model_call_args[1]['device_map'] == "cpu"
        assert model_call_args[1]['dtype'].name == "float32"  # torch.float32
        assert model_call_args[1]['attn_implementation'] == "eager"
        assert model_call_args[1]['low_cpu_mem_usage'] is True

    @patch('torch.cuda.is_available')
    def test_cpu_interface_creation(self, mock_cuda):
        """Test Gradio interface creation in CPU environment."""
        mock_cuda.return_value = False

        # Create mock model loader in the "not yet loaded" state.
        mock_loader = MagicMock(spec=ModelLoader)
        mock_loader.is_loaded = False
        mock_loader.get_model_info.return_value = {"status": "not_loaded"}

        # Create interface
        interface = GradioInterface(mock_loader)

        # Building the UI must not raise even without a loaded model.
        # (Return value intentionally discarded; we only verify construction.)
        with patch('gradio.Blocks') as mock_blocks:
            interface.create_interface()
            mock_blocks.assert_called_once()


class TestGPUWorkflow:
    """Test complete workflow in GPU environment."""

    @patch('torch.cuda.is_available')
    @patch('app.config.model_config.DependencyValidator.is_environment_ready')
    def test_gpu_environment_detection(self, mock_env_ready, mock_cuda):
        """Test GPU environment is detected correctly."""
        mock_env_ready.return_value = True
        mock_cuda.return_value = True

        config = EnvironmentDetector.create_model_config()

        # GPU path: automatic device placement, bf16 precision, SDPA
        # attention, and no low-memory loading constraint.
        assert config.device_map == "auto"
        assert config.dtype.name == "bfloat16"  # torch.bfloat16
        assert config.attn_implementation == "sdpa"
        assert config.low_cpu_mem_usage is False

    @patch('torch.cuda.is_available')
    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    @patch('transformers.AutoTokenizer.from_pretrained')
    @patch('transformers.AutoModelForCausalLM.from_pretrained')
    @patch('transformers.pipeline')
    def test_gpu_model_loading_workflow(self, mock_pipeline, mock_model,
                                       mock_tokenizer, mock_env_ready, mock_cuda):
        """Test complete model loading workflow on GPU."""
        # Simulate a healthy environment with CUDA present.
        mock_cuda.return_value = True
        mock_env_ready.return_value = True

        mock_tokenizer.return_value = MagicMock()

        # Model stub: eval() is fluent and parameter count is reported.
        fake_model = MagicMock()
        fake_model.eval.return_value = fake_model
        fake_model.num_parameters.return_value = 41900000000  # 41.9B
        mock_model.return_value = fake_model

        # Pipeline stub whose smoke test appears to succeed.
        fake_pipe = MagicMock()
        fake_pipe.return_value = [{"generated_text": "Hello world"}]
        mock_pipeline.return_value = fake_pipe

        # Exercise the full loading workflow.
        loader = ModelLoader()
        success = loader.load_complete_model()

        assert success is True
        assert loader.is_loaded is True

        # The model must have been requested with GPU-specific kwargs.
        kwargs = mock_model.call_args[1]
        assert kwargs['device_map'] == "auto"
        assert kwargs['dtype'].name == "bfloat16"  # torch.bfloat16
        assert kwargs['attn_implementation'] == "sdpa"
        assert kwargs['low_cpu_mem_usage'] is False


class TestEnvironmentVariableWorkflow:
    """Test workflow with environment variables."""

    @patch.dict(os.environ, {
        'HF_MODEL_ID': 'custom/test-model',
        'HF_REVISION': 'test-revision-123'
    })
    @patch('torch.cuda.is_available')
    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    @patch('transformers.AutoTokenizer.from_pretrained')
    @patch('transformers.AutoModelForCausalLM.from_pretrained')
    def test_environment_variables_respected(self, mock_model, mock_tokenizer,
                                           mock_env_ready, mock_cuda):
        """Test that environment variables are properly used."""
        mock_env_ready.return_value = True
        mock_cuda.return_value = False

        stub_model = MagicMock()
        stub_model.eval.return_value = stub_model
        mock_model.return_value = stub_model
        mock_tokenizer.return_value = MagicMock()

        loader = ModelLoader()
        loader.create_config()

        # The config must pick up the patched HF_* environment variables.
        assert loader.config.model_id == "custom/test-model"
        assert loader.config.revision == "test-revision-123"

        # Try loading the tokenizer as well.
        loader.load_tokenizer()

        # The tokenizer download must be pinned to the same id/revision.
        mock_tokenizer.assert_called_once_with(
            "custom/test-model",
            trust_remote_code=True,
            revision="test-revision-123"
        )


class TestErrorHandlingWorkflow:
    """Test error handling in complete workflow."""

    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    def test_missing_dependencies_workflow(self, mock_env_ready):
        """Test workflow when dependencies are missing."""
        mock_env_ready.return_value = False

        loader = ModelLoader()

        # Loading must fail fast and leave the loader unloaded.
        assert loader.load_complete_model() is False
        assert loader.is_loaded is False

    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    @patch('transformers.AutoTokenizer.from_pretrained')
    def test_tokenizer_loading_failure(self, mock_tokenizer, mock_env_ready):
        """Test workflow when tokenizer loading fails."""
        mock_env_ready.return_value = True
        mock_tokenizer.side_effect = Exception("Tokenizer loading failed")

        loader = ModelLoader()

        # A tokenizer error must be reported as an overall failure.
        assert loader.load_complete_model() is False
        assert loader.is_loaded is False

    @patch('app.model_loader.DependencyValidator.is_environment_ready')
    @patch('transformers.AutoTokenizer.from_pretrained')
    @patch('transformers.AutoModelForCausalLM.from_pretrained')
    def test_model_loading_failure(self, mock_model, mock_tokenizer, mock_env_ready):
        """Test workflow when model loading fails."""
        mock_env_ready.return_value = True
        mock_tokenizer.return_value = MagicMock()
        mock_model.side_effect = Exception("Model loading failed")

        loader = ModelLoader()

        # A model error after a good tokenizer still fails the workflow.
        assert loader.load_complete_model() is False
        assert loader.is_loaded is False

    def test_interface_with_failed_model_loading(self):
        """Test interface creation when model loading fails."""
        # Loader stub stuck in the failed / not-loaded state.
        failed_loader = MagicMock(spec=ModelLoader)
        failed_loader.is_loaded = False
        failed_loader.get_model_info.return_value = {"status": "not_loaded"}

        interface = GradioInterface(failed_loader)

        # Fallback responses should explain the outage to the user.
        response = interface.response_generator.generate_response("Test query")

        assert "model is currently unavailable" in response
        assert "Expert Type:" in response


class TestRevisionSelectionWorkflow:
    """Test revision selection workflow."""

    @patch('torch.cuda.is_available')
    @patch.dict(os.environ, {}, clear=True)
    @patch('scripts.select_revision.RevisionSelector.find_cpu_safe_revision')
    @patch('scripts.select_revision.RevisionSelector.save_revision_to_env')
    def test_cpu_revision_selection_workflow(self, mock_save, mock_find, mock_cuda):
        """Test CPU revision selection workflow."""
        mock_find.return_value = "safe-revision-123"
        mock_cuda.return_value = False

        from scripts.select_revision import main

        # On CPU with no pinned revision, the selector must run and
        # persist the revision it found.
        assert main() == 0
        mock_find.assert_called_once()
        mock_save.assert_called_once_with("safe-revision-123")

    @patch('torch.cuda.is_available')
    def test_gpu_skips_revision_selection(self, mock_cuda):
        """Test that GPU environment skips revision selection."""
        mock_cuda.return_value = True

        from scripts.select_revision import main

        # GPU environments never need a CPU-safe revision.
        assert main() == 0  # Should exit early with success

    @patch('torch.cuda.is_available')
    @patch.dict(os.environ, {'HF_REVISION': 'existing-revision'})
    def test_existing_revision_skips_selection(self, mock_cuda):
        """Test that existing revision skips selection."""
        mock_cuda.return_value = False  # CPU environment

        from scripts.select_revision import main

        # A pre-set HF_REVISION short-circuits the selection logic.
        assert main() == 0  # Should exit early with success


class TestPreStartWorkflow:
    """Test prestart script workflow."""

    def test_prestart_script_exists_and_executable(self):
        """Test that prestart script exists and is executable."""
        prestart_path = project_root / "prestart.sh"

        # The script must be present and runnable by the deploy user.
        assert prestart_path.exists(), "prestart.sh should exist"
        assert os.access(prestart_path, os.X_OK), "prestart.sh should be executable"

    @patch('subprocess.check_call')
    @patch('torch.cuda.is_available')
    def test_prestart_cpu_workflow(self, mock_cuda, mock_subprocess):
        """Test prestart workflow on CPU."""
        mock_cuda.return_value = False

        # On CPU, prestart.sh is expected to:
        #   1. Install core dependencies
        #   2. Skip flash-attn installation
        #   3. Run the revision selector
        # The bash script itself cannot be exercised directly from
        # pytest; the Python components it calls are covered by the
        # other test classes in this module, so this remains a
        # documented placeholder.
        assert True  # Placeholder for actual script testing


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly
    # (e.g. in CI) reports failures instead of always exiting 0.
    sys.exit(pytest.main([__file__]))