Spaces:
Runtime error
Runtime error
File size: 2,933 Bytes
c3efd49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
"""Reproducibility utilities for deterministic training."""
import random
import numpy as np
import torch
import os
from typing import Optional
import logging
logger = logging.getLogger(__name__)
def set_random_seeds(seed: int) -> None:
    """
    Seed every random number generator used for training.

    Seeds Python's ``random`` module, NumPy's global generator and the
    PyTorch CPU generator; when CUDA is available it additionally seeds the
    current CUDA device and all CUDA devices.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
    logger.info(f"Random seeds set to {seed}")
def set_deterministic_mode(enabled: bool = True) -> None:
    """
    Enable or disable deterministic mode for PyTorch operations.

    Note: Deterministic mode may reduce performance but ensures reproducibility.

    When enabling, ``CUBLAS_WORKSPACE_CONFIG`` is set to ``":4096:8"`` unless
    it is already present in the environment. With CUDA >= 10.2,
    ``torch.use_deterministic_algorithms(True)`` makes cuBLAS-backed
    operations on CUDA tensors raise ``RuntimeError`` unless this variable is
    ``":4096:8"`` or ``":16:8"``. A value set by the caller is left untouched.

    Args:
        enabled: Whether to enable deterministic mode. ``True`` pins cuDNN to
            deterministic kernels and disables autotuning; ``False`` restores
            non-deterministic, benchmark-tuned behaviour.
    """
    if enabled:
        # Should be in place before the first cuBLAS call; otherwise
        # deterministic-algorithm checks reject cuBLAS ops on CUDA >= 10.2.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # use_deterministic_algorithms only exists in PyTorch >= 1.8
        if hasattr(torch, 'use_deterministic_algorithms'):
            torch.use_deterministic_algorithms(True)
        logger.info("Deterministic mode enabled")
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
        if hasattr(torch, 'use_deterministic_algorithms'):
            torch.use_deterministic_algorithms(False)
        logger.info("Deterministic mode disabled")
def get_environment_info() -> dict:
    """
    Get information about the execution environment.

    Collects the interpreter, platform and library versions that influence
    whether a seeded run can be reproduced (Python, PyTorch, NumPy), plus
    CUDA/GPU details when CUDA is available.

    Returns:
        Dictionary with keys ``python_version``, ``platform``,
        ``pytorch_version``, ``numpy_version`` and ``cuda_available``. When
        CUDA is available it also contains ``cuda_version``,
        ``cudnn_version``, ``gpu_count`` and ``gpu_names`` (one name per
        visible device).
    """
    import platform
    import sys

    info = {
        'python_version': sys.version,
        'platform': platform.platform(),
        'pytorch_version': torch.__version__,
        # NumPy's global RNG is seeded by set_random_seeds, so its version
        # is part of what is needed to reproduce a run.
        'numpy_version': np.__version__,
        'cuda_available': torch.cuda.is_available(),
    }
    if info['cuda_available']:
        gpu_count = torch.cuda.device_count()
        info['cuda_version'] = torch.version.cuda
        info['cudnn_version'] = torch.backends.cudnn.version()
        info['gpu_count'] = gpu_count
        info['gpu_names'] = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
    return info
def log_environment_info() -> None:
    """
    Write the result of get_environment_info() to the module logger.

    Emits a header framed by lines of 80 ``=`` characters, then one
    ``key: value`` line per entry, then a closing ``=`` line.
    """
    details = get_environment_info()
    banner = "=" * 80
    for header_line in (banner, "Environment Information:", banner):
        logger.info(header_line)
    for name, setting in details.items():
        logger.info(f"{name}: {setting}")
    logger.info(banner)
def setup_reproducibility(seed: int, deterministic: bool = False) -> None:
    """
    Prepare a run for reproducibility and record the environment.

    Seeds all generators via set_random_seeds. Deterministic mode is only
    switched on when requested; otherwise the current cuDNN / deterministic
    algorithm settings are left unchanged. The environment details are
    logged last in both cases.

    Args:
        seed: Random seed value
        deterministic: Whether to enable deterministic mode
    """
    set_random_seeds(seed)
    if not deterministic:
        log_environment_info()
        return
    set_deterministic_mode(True)
    log_environment_info()
|