# mbellan's picture
# Initial deployment
# c3efd49
"""Reproducibility utilities for deterministic training."""
import random
import numpy as np
import torch
import os
from typing import Optional
import logging
logger = logging.getLogger(__name__)
def set_random_seeds(seed: int) -> None:
    """
    Set random seeds for all libraries to ensure reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (and,
    when CUDA is available, every CUDA device).

    Args:
        seed: Random seed value. Must lie in ``[0, 2**32 - 1]``, the range
            accepted by ``np.random.seed``.

    Raises:
        ValueError: If ``seed`` is outside ``[0, 2**32 - 1]``. The check runs
            before any generator is touched, so a rejected seed leaves the
            existing RNG state unchanged.
    """
    # np.random.seed rejects values outside this range. Validating up front
    # avoids re-seeding `random` and then failing halfway through.
    if not 0 <= seed <= 2**32 - 1:
        raise ValueError(f"Seed must be between 0 and 2**32 - 1, got {seed}")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # manual_seed_all seeds every GPU, the current device included, so a
        # separate torch.cuda.manual_seed call is not needed.
        torch.cuda.manual_seed_all(seed)

    logger.info("Random seeds set to %s", seed)
def set_deterministic_mode(enabled: bool = True) -> None:
    """
    Enable or disable deterministic mode for PyTorch operations.

    Note: Deterministic mode may reduce performance but ensures reproducibility.

    When enabling, ``CUBLAS_WORKSPACE_CONFIG`` is given a default value if the
    environment does not already define it. PyTorch's reproducibility notes
    state that with CUDA >= 10.2 some cuBLAS-backed operations raise a
    ``RuntimeError`` under ``torch.use_deterministic_algorithms(True)`` unless
    this variable is set. An existing value is never overwritten, and the
    variable is left in place when deterministic mode is disabled.

    Args:
        enabled: Whether to enable deterministic mode
    """
    flag = bool(enabled)

    if flag:
        # NOTE(review): call this before any CUDA work is done; the variable
        # is presumably read when cuBLAS is first initialised - confirm.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # cuDNN benchmarking is switched off whenever determinism is requested
    # and back on otherwise, so the two flags are always opposites here.
    torch.backends.cudnn.deterministic = flag
    torch.backends.cudnn.benchmark = not flag

    # For PyTorch >= 1.8
    if hasattr(torch, 'use_deterministic_algorithms'):
        torch.use_deterministic_algorithms(flag)

    if flag:
        logger.info("Deterministic mode enabled")
    else:
        logger.info("Deterministic mode disabled")
def get_environment_info() -> dict:
    """
    Collect details about the runtime this process is executing in.

    Returns:
        Dictionary with the Python version string, the platform description,
        the PyTorch version and a CUDA availability flag. When CUDA is
        available, the CUDA version, cuDNN version, GPU count and the list of
        GPU names are included as well.
    """
    import platform
    import sys

    has_cuda = torch.cuda.is_available()
    env = dict(
        python_version=sys.version,
        platform=platform.platform(),
        pytorch_version=torch.__version__,
        cuda_available=has_cuda,
    )

    # Without CUDA there is nothing GPU-related to report.
    if not has_cuda:
        return env

    n_gpus = torch.cuda.device_count()
    env['cuda_version'] = torch.version.cuda
    env['cudnn_version'] = torch.backends.cudnn.version()
    env['gpu_count'] = n_gpus
    env['gpu_names'] = [torch.cuda.get_device_name(idx) for idx in range(n_gpus)]
    return env
def log_environment_info() -> None:
    """Write the collected environment details to the module logger.

    Emits a three-line banner, one ``key: value`` line per entry returned by
    ``get_environment_info()``, and a closing divider, all at INFO level.
    """
    details = get_environment_info()
    divider = "=" * 80

    report = [divider, "Environment Information:", divider]
    report.extend(f"{name}: {detail}" for name, detail in details.items())
    report.append(divider)

    for line in report:
        logger.info(line)
def setup_reproducibility(seed: int, deterministic: bool = False) -> None:
    """
    Prepare a run for reproducibility.

    Seeds the random number generators, turns on deterministic mode when
    requested, and finally logs the environment information.

    Args:
        seed: Random seed value
        deterministic: Whether to enable deterministic mode
    """
    set_random_seeds(seed)

    # Deterministic mode is opt-in: when it is not requested the helper is
    # simply not called, so the current settings are left as they are.
    if deterministic:
        set_deterministic_mode(enabled=True)

    log_environment_info()