# Spaces: Sleeping
"""
Contract tests for Hugging Face API interactions.

These tests verify that our code interacts correctly with external HF services.
They can be run against the real APIs or with mocks for CI/CD.
"""
| import pytest | |
| import os | |
| from unittest.mock import patch, MagicMock | |
| import sys | |
| from pathlib import Path | |
# Add the project root (three directory levels above this file) to sys.path so
# the `scripts` and `app` packages used by these tests can be imported.
# Resolve first: with a relative __file__ such as "test_x.py", the chained
# .parent calls would otherwise collapse to "." instead of the real root.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
| from scripts.select_revision import RevisionSelector | |
| from huggingface_hub import HfApi | |
| from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError | |
class TestHuggingFaceAPIContract:
    """Test contracts with Hugging Face API.

    NOTE(review): several tests below take a ``mock_*`` argument. No
    ``@patch`` decorator or fixture that supplies those arguments is visible
    in this view of the file; confirm the real file provides them, otherwise
    pytest will report a missing fixture for each of those tests.
    """

    def setup_method(self):
        """Set up the model id and a ``RevisionSelector`` for it."""
        self.test_model_id = "microsoft/Phi-3.5-MoE-instruct"
        self.selector = RevisionSelector(self.test_model_id)

    def _run_cpu_safety_check(self, source_text, revision="abc123"):
        """Call ``is_cpu_safe_revision`` with ``open(...).read()`` returning *source_text*.

        Only the context-manager ``read()`` path of ``open`` is stubbed.
        """
        with patch('builtins.open', create=True) as fake_open:
            fake_open.return_value.__enter__.return_value.read.return_value = source_text
            return self.selector.is_cpu_safe_revision(revision)

    def _save_revision_and_read_env(self, revision, initial_content=""):
        """Run ``save_revision_to_env`` against a temporary .env file.

        The file is created with *initial_content*, ``ENV_FILE`` in
        ``scripts.select_revision`` is patched to point at it, and the file's
        text after the call is returned. The file is removed afterwards.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(mode='w+', suffix='.env', delete=False) as tmp_file:
            tmp_file.write(initial_content)
            tmp_path = Path(tmp_file.name)
        # The handle is closed (and flushed) here, before the selector touches the file.
        try:
            with patch('scripts.select_revision.ENV_FILE', tmp_path):
                self.selector.save_revision_to_env(revision)
            return tmp_path.read_text()
        finally:
            # Clean up even when the call above fails.
            if tmp_path.exists():
                tmp_path.unlink()

    def test_hf_api_connection(self):
        """Test that we can connect to HF API (requires internet).

        The test is skipped only when the lookup itself raises. The assertions
        run outside the ``try`` so that a wrong response fails the test
        instead of being reported as a skip.
        """
        api = HfApi()
        try:
            # Try to get model info - this should work for public models
            model_info = api.model_info(self.test_model_id)
        except Exception as e:
            pytest.skip(f"Cannot connect to HF API: {e}")
        else:
            assert model_info is not None
            assert model_info.modelId == self.test_model_id

    def test_get_recent_commits_contract(self, mock_list_commits):
        """Test contract for getting recent commits.

        The mock returns three commits; with ``max_commits=2`` only the first
        two commit ids are expected back.
        """
        mock_list_commits.return_value = [
            MagicMock(commit_id="abc123"),
            MagicMock(commit_id="def456"),
            MagicMock(commit_id="ghi789"),
        ]

        commits = self.selector.get_recent_commits(max_commits=2)

        # Verify API was called correctly
        mock_list_commits.assert_called_once_with(
            repo_id=self.test_model_id,
            repo_type="model",
        )
        # Verify we got the expected number of commits
        assert len(commits) == 2
        assert commits == ["abc123", "def456"]

    def test_get_recent_commits_api_error(self, mock_list_commits):
        """Test handling of API errors when getting commits."""
        mock_list_commits.side_effect = RepositoryNotFoundError("Model not found")

        commits = self.selector.get_recent_commits()

        # Should return empty list on error
        assert commits == []

    def test_is_cpu_safe_revision_contract(self, mock_download):
        """Test contract for checking CPU-safe revisions."""
        mock_download.return_value = "/tmp/test_modeling.py"
        # File content without flash_attn imports
        safe_content = """
import torch
import torch.nn as nn
from transformers import PreTrainedModel
class TestModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # No flash_attn imports here
"""

        result = self._run_cpu_safety_check(safe_content)

        # Verify download was called correctly
        mock_download.assert_called_once_with(
            repo_id=self.test_model_id,
            filename="modeling_phimoe.py",
            revision="abc123",
            repo_type="model",
            cache_dir=".cache",
        )
        assert result is True

    def test_is_cpu_safe_revision_with_flash_attn(self, mock_download):
        """Test detection of flash_attn imports."""
        mock_download.return_value = "/tmp/test_modeling.py"
        # File content WITH flash_attn imports
        unsafe_content = """
import torch
import torch.nn as nn
import flash_attn
from transformers import PreTrainedModel
class TestModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
"""

        result = self._run_cpu_safety_check(unsafe_content)

        assert result is False

    def test_is_cpu_safe_revision_download_error(self, mock_download):
        """Test handling of download errors."""
        mock_download.side_effect = RevisionNotFoundError("Revision not found")

        result = self.selector.is_cpu_safe_revision("nonexistent")

        # Should return False on download error
        assert result is False

    def test_save_revision_to_env_contract(self):
        """Test contract for saving revision to .env file."""
        test_revision = "abc123def456"

        content = self._save_revision_and_read_env(test_revision)

        # Verify file was written correctly
        assert f"HF_REVISION={test_revision}" in content

    def test_save_revision_to_env_existing_file(self):
        """Test saving revision when .env file already exists."""
        test_revision = "new123revision"
        existing_content = """
# Existing env file
SOME_VAR=value
HF_REVISION=old123revision
OTHER_VAR=other_value
"""

        content = self._save_revision_and_read_env(test_revision, existing_content)

        # Should have new revision
        assert f"HF_REVISION={test_revision}" in content
        # Should not have old revision
        assert "HF_REVISION=old123revision" not in content
        # Should preserve other variables
        assert "SOME_VAR=value" in content
        assert "OTHER_VAR=other_value" in content
class TestTransformersContract:
    """Test contracts with transformers library.

    NOTE(review): both tests take a ``mock_*`` argument. No ``@patch``
    decorator or fixture that supplies it is visible in this view of the
    file; confirm the real file provides one, otherwise pytest will report a
    missing fixture.
    """

    @staticmethod
    def _make_loader():
        """Return a ``ModelLoader`` carrying the minimal CPU config used by these tests."""
        # Imported inside the function, so these modules are only loaded when a test runs.
        from app.model_loader import ModelLoader
        from app.config.model_config import ModelConfig
        import torch

        loader = ModelLoader()
        loader.config = ModelConfig(
            model_id="test/model",
            revision="main",
            dtype=torch.float32,
            device_map="cpu",
            attn_implementation="eager",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        )
        return loader

    def test_tokenizer_loading_contract(self, mock_tokenizer):
        """Test contract for tokenizer loading."""
        mock_tokenizer_instance = MagicMock()
        mock_tokenizer.return_value = mock_tokenizer_instance
        loader = self._make_loader()

        result = loader.load_tokenizer()

        # Verify tokenizer was called with correct parameters
        mock_tokenizer.assert_called_once_with(
            "test/model",
            trust_remote_code=True,
            revision="main",
        )
        assert result is True
        assert loader.tokenizer == mock_tokenizer_instance

    def test_model_loading_contract(self, mock_model):
        """Test contract for model loading."""
        import torch

        mock_model_instance = MagicMock()
        # eval() returns the same mock, so either `model.eval()` or
        # `model = model.eval()` leaves the loader holding this instance.
        mock_model_instance.eval.return_value = mock_model_instance
        mock_model.return_value = mock_model_instance
        loader = self._make_loader()

        result = loader.load_model()

        # Verify model was called with correct parameters
        mock_model.assert_called_once_with(
            "test/model",
            trust_remote_code=True,
            revision="main",
            attn_implementation="eager",
            dtype=torch.float32,  # Should use dtype, not torch_dtype
            device_map="cpu",
            low_cpu_mem_usage=True,
        )
        # Verify eval() was called
        mock_model_instance.eval.assert_called_once()
        assert result is True
        assert loader.model == mock_model_instance
if __name__ == "__main__":
    # pytest.main() returns the exit code; raise it so that running this file
    # directly exits non-zero when tests fail instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))