""" Contract tests for Hugging Face API interactions. These tests verify that our code correctly interacts with external HF services. They can be run against real APIs or mocked for CI/CD. """ import pytest import os from unittest.mock import patch, MagicMock import sys from pathlib import Path # Add project root to path project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from scripts.select_revision import RevisionSelector from huggingface_hub import HfApi from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError class TestHuggingFaceAPIContract: """Test contracts with Hugging Face API.""" def setup_method(self): """Setup test fixtures.""" self.test_model_id = "microsoft/Phi-3.5-MoE-instruct" self.selector = RevisionSelector(self.test_model_id) @pytest.mark.integration def test_hf_api_connection(self): """Test that we can connect to HF API (requires internet).""" api = HfApi() try: # Try to get model info - this should work for public models model_info = api.model_info(self.test_model_id) assert model_info is not None assert model_info.modelId == self.test_model_id except Exception as e: pytest.skip(f"Cannot connect to HF API: {e}") @patch('huggingface_hub.HfApi.list_repo_commits') def test_get_recent_commits_contract(self, mock_list_commits): """Test contract for getting recent commits.""" # Mock commit objects mock_commits = [ MagicMock(commit_id="abc123"), MagicMock(commit_id="def456"), MagicMock(commit_id="ghi789") ] mock_list_commits.return_value = mock_commits commits = self.selector.get_recent_commits(max_commits=2) # Verify API was called correctly mock_list_commits.assert_called_once_with( repo_id=self.test_model_id, repo_type="model" ) # Verify we got the expected number of commits assert len(commits) == 2 assert commits == ["abc123", "def456"] @patch('huggingface_hub.HfApi.list_repo_commits') def test_get_recent_commits_api_error(self, mock_list_commits): """Test handling of API errors when getting commits.""" mock_list_commits.side_effect = RepositoryNotFoundError("Model not found") commits = self.selector.get_recent_commits() # Should return empty list on error assert commits == [] @patch('huggingface_hub.hf_hub_download') def test_is_cpu_safe_revision_contract(self, mock_download): """Test contract for checking CPU-safe revisions.""" # Mock file content without flash_attn imports mock_file_path = "/tmp/test_modeling.py" mock_download.return_value = mock_file_path # Create mock file content safe_content = """ import torch import torch.nn as nn from transformers import PreTrainedModel class TestModel(PreTrainedModel): def __init__(self, config): super().__init__(config) # No flash_attn imports here """ with patch('builtins.open', create=True) as mock_open: mock_open.return_value.__enter__.return_value.read.return_value = safe_content result = self.selector.is_cpu_safe_revision("abc123") # Verify download was called correctly mock_download.assert_called_once_with( repo_id=self.test_model_id, filename="modeling_phimoe.py", revision="abc123", repo_type="model", cache_dir=".cache" ) assert result is True @patch('huggingface_hub.hf_hub_download') def test_is_cpu_safe_revision_with_flash_attn(self, mock_download): """Test detection of flash_attn imports.""" mock_file_path = "/tmp/test_modeling.py" mock_download.return_value = mock_file_path # Mock file content WITH flash_attn imports unsafe_content = """ import torch import torch.nn as nn import flash_attn from transformers import PreTrainedModel class TestModel(PreTrainedModel): def __init__(self, config): super().__init__(config) """ with patch('builtins.open', create=True) as mock_open: mock_open.return_value.__enter__.return_value.read.return_value = unsafe_content result = self.selector.is_cpu_safe_revision("abc123") assert result is False @patch('huggingface_hub.hf_hub_download') def test_is_cpu_safe_revision_download_error(self, mock_download): """Test handling of download errors.""" mock_download.side_effect = RevisionNotFoundError("Revision not found") result = self.selector.is_cpu_safe_revision("nonexistent") # Should return False on download error assert result is False def test_save_revision_to_env_contract(self): """Test contract for saving revision to .env file.""" test_revision = "abc123def456" # Use a temporary file for testing import tempfile with tempfile.NamedTemporaryFile(mode='w+', suffix='.env', delete=False) as tmp_file: tmp_path = Path(tmp_file.name) try: # Patch the ENV_FILE path with patch('scripts.select_revision.ENV_FILE', tmp_path): self.selector.save_revision_to_env(test_revision) # Verify file was written correctly content = tmp_path.read_text() assert f"HF_REVISION={test_revision}" in content finally: # Clean up if tmp_path.exists(): tmp_path.unlink() def test_save_revision_to_env_existing_file(self): """Test saving revision when .env file already exists.""" test_revision = "new123revision" existing_content = """ # Existing env file SOME_VAR=value HF_REVISION=old123revision OTHER_VAR=other_value """ import tempfile with tempfile.NamedTemporaryFile(mode='w+', suffix='.env', delete=False) as tmp_file: tmp_file.write(existing_content) tmp_file.flush() tmp_path = Path(tmp_file.name) try: with patch('scripts.select_revision.ENV_FILE', tmp_path): self.selector.save_revision_to_env(test_revision) content = tmp_path.read_text() # Should have new revision assert f"HF_REVISION={test_revision}" in content # Should not have old revision assert "HF_REVISION=old123revision" not in content # Should preserve other variables assert "SOME_VAR=value" in content assert "OTHER_VAR=other_value" in content finally: if tmp_path.exists(): tmp_path.unlink() class TestTransformersContract: """Test contracts with transformers library.""" @patch('transformers.AutoTokenizer.from_pretrained') def test_tokenizer_loading_contract(self, mock_tokenizer): """Test contract for tokenizer loading.""" mock_tokenizer_instance = MagicMock() mock_tokenizer.return_value = mock_tokenizer_instance from app.model_loader import ModelLoader loader = ModelLoader() # Create a minimal config from app.config.model_config import ModelConfig import torch loader.config = ModelConfig( model_id="test/model", revision="main", dtype=torch.float32, device_map="cpu", attn_implementation="eager", low_cpu_mem_usage=True, trust_remote_code=True ) result = loader.load_tokenizer() # Verify tokenizer was called with correct parameters mock_tokenizer.assert_called_once_with( "test/model", trust_remote_code=True, revision="main" ) assert result is True assert loader.tokenizer == mock_tokenizer_instance @patch('transformers.AutoModelForCausalLM.from_pretrained') def test_model_loading_contract(self, mock_model): """Test contract for model loading.""" mock_model_instance = MagicMock() mock_model_instance.eval.return_value = mock_model_instance mock_model.return_value = mock_model_instance from app.model_loader import ModelLoader loader = ModelLoader() # Create a minimal config from app.config.model_config import ModelConfig import torch loader.config = ModelConfig( model_id="test/model", revision="main", dtype=torch.float32, device_map="cpu", attn_implementation="eager", low_cpu_mem_usage=True, trust_remote_code=True ) result = loader.load_model() # Verify model was called with correct parameters mock_model.assert_called_once_with( "test/model", trust_remote_code=True, revision="main", attn_implementation="eager", dtype=torch.float32, # Should use dtype, not torch_dtype device_map="cpu", low_cpu_mem_usage=True ) # Verify eval() was called mock_model_instance.eval.assert_called_once() assert result is True assert loader.model == mock_model_instance if __name__ == "__main__": pytest.main([__file__])