Spaces:
Runtime error
Runtime error
| """Reproducibility utilities for deterministic training.""" | |
| import random | |
| import numpy as np | |
| import torch | |
| import os | |
| from typing import Optional | |
| import logging | |
# Module-level logger used by the helpers in this module to report what they set.
logger = logging.getLogger(__name__)
def set_random_seeds(seed: int) -> None:
    """Seed every random number generator the training code relies on.

    Python's ``random`` module, NumPy's global RNG and PyTorch's default
    generator are always seeded. The CUDA generators are seeded explicitly
    only when a CUDA device is available.

    Args:
        seed: Value handed to each library's seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Seed the current device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    logger.info(f"Random seeds set to {seed}")
def set_deterministic_mode(enabled: bool = True) -> None:
    """Enable or disable deterministic execution of PyTorch operations.

    Enabling trades speed for reproducibility:
    - cuDNN is restricted to deterministic kernels.
    - cuDNN autotuning (``benchmark``) is switched off.
    - Where available (PyTorch >= 1.8), ``torch.use_deterministic_algorithms``
      is turned on.

    Disabling reverts those settings and turns cuDNN autotuning back on.

    Args:
        enabled: ``True`` to turn deterministic mode on, ``False`` to turn it off.
    """
    # PyTorch >= 1.8 exposes a global switch for deterministic algorithms.
    has_global_switch = hasattr(torch, 'use_deterministic_algorithms')

    if enabled:
        # With CUDA >= 10.2, cuBLAS is only deterministic with a fixed
        # workspace configuration. Without it, use_deterministic_algorithms(True)
        # makes affected CUDA ops raise RuntimeError. setdefault keeps a value
        # the user already exported. The variable must be in place before the
        # first cuBLAS call in the process to take effect.
        os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        if has_global_switch:
            torch.use_deterministic_algorithms(True)
        logger.info("Deterministic mode enabled")
    else:
        torch.backends.cudnn.deterministic = False
        # Let cuDNN autotune kernels again for speed.
        torch.backends.cudnn.benchmark = True
        if has_global_switch:
            torch.use_deterministic_algorithms(False)
        logger.info("Deterministic mode disabled")
def get_environment_info() -> dict:
    """Describe the Python/PyTorch runtime this process is executing in.

    Returns:
        A dict that always contains ``python_version``, ``platform``,
        ``pytorch_version`` and ``cuda_available``, in that order. When
        CUDA is available it also contains ``cuda_version``,
        ``cudnn_version``, ``gpu_count`` and ``gpu_names``. ``gpu_names``
        holds one name per visible device, in index order.
    """
    import platform
    import sys

    has_cuda = torch.cuda.is_available()
    info = dict(
        python_version=sys.version,
        platform=platform.platform(),
        pytorch_version=torch.__version__,
        cuda_available=has_cuda,
    )
    if not has_cuda:
        return info

    n_devices = torch.cuda.device_count()
    info.update(
        cuda_version=torch.version.cuda,
        cudnn_version=torch.backends.cudnn.version(),
        gpu_count=n_devices,
        gpu_names=[torch.cuda.get_device_name(idx) for idx in range(n_devices)],
    )
    return info
def log_environment_info() -> None:
    """Write the result of ``get_environment_info`` to the module logger.

    Each entry is logged at INFO level as ``key: value`` on its own line.
    The block is framed by 80-character ``=`` rules, with a title line
    after the top rule.
    """
    rule = "=" * 80
    details = get_environment_info()

    for header_line in (rule, "Environment Information:", rule):
        logger.info(header_line)
    for name, val in details.items():
        logger.info(f"{name}: {val}")
    logger.info(rule)
def setup_reproducibility(seed: int, deterministic: bool = False) -> None:
    """Seed all RNGs, optionally force deterministic kernels, then log the environment.

    Args:
        seed: Value forwarded to ``set_random_seeds``.
        deterministic: When true, ``set_deterministic_mode(True)`` is also
            applied. When false, the current deterministic setting is left
            as it is; it is not explicitly disabled.
    """
    set_random_seeds(seed)
    if not deterministic:
        log_environment_info()
        return

    set_deterministic_mode(True)
    log_environment_info()