File size: 2,933 Bytes
c3efd49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Reproducibility utilities for deterministic training."""
import random
import numpy as np
import torch
import os
from typing import Optional
import logging

logger = logging.getLogger(__name__)


def set_random_seeds(seed: int) -> None:
    """
    Seed Python's ``random`` module, NumPy's global RNG and PyTorch with one value.

    When CUDA is available, the current CUDA device and all CUDA devices are
    seeded as well.

    Args:
        seed: Value handed to every seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    logger.info(f"Random seeds set to {seed}")


def set_deterministic_mode(enabled: bool = True) -> None:
    """
    Enable or disable deterministic mode for PyTorch operations.

    Note: Deterministic mode may reduce performance but improves reproducibility.

    When enabling, this also sets the ``CUBLAS_WORKSPACE_CONFIG`` environment
    variable to ``:4096:8`` unless the caller has already set it. PyTorch
    documents that, on CUDA >= 10.2, cuBLAS-backed operations are only
    deterministic with this variable set. With
    ``torch.use_deterministic_algorithms(True)`` active and the variable
    missing, those operations raise ``RuntimeError``. Setting it before any
    CUDA work is the safest choice.

    Args:
        enabled: Whether to enable deterministic mode. Disabling it also
            re-enables cuDNN autotuning (``cudnn.benchmark = True``).
    """
    if enabled:
        # Respect an explicit user choice (e.g. ":16:8"); otherwise use the
        # value recommended in the PyTorch reproducibility docs.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.backends.cudnn.deterministic = True
        # Autotuning may pick different kernels from run to run.
        torch.backends.cudnn.benchmark = False
        # For PyTorch >= 1.8
        if hasattr(torch, 'use_deterministic_algorithms'):
            torch.use_deterministic_algorithms(True)
        logger.info("Deterministic mode enabled")
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
        if hasattr(torch, 'use_deterministic_algorithms'):
            torch.use_deterministic_algorithms(False)
        logger.info("Deterministic mode disabled")


def get_environment_info() -> dict:
    """
    Get information about the execution environment.

    Collects the interpreter, OS and library versions that influence
    run-to-run reproducibility. CUDA-specific entries are added only when
    ``torch.cuda.is_available()`` returns True.

    Returns:
        Dictionary with the keys ``python_version``, ``platform``,
        ``numpy_version``, ``pytorch_version`` and ``cuda_available``. When
        CUDA is available, it also contains ``cuda_version``,
        ``cudnn_version``, ``gpu_count`` and ``gpu_names``.
    """
    import sys
    import platform

    cuda_available = torch.cuda.is_available()
    info = {
        'python_version': sys.version,
        'platform': platform.platform(),
        # NumPy's global RNG is seeded by set_random_seeds, so its version
        # matters for reproducing results too.
        'numpy_version': np.__version__,
        'pytorch_version': torch.__version__,
        'cuda_available': cuda_available,
    }

    if cuda_available:
        gpu_count = torch.cuda.device_count()
        info['cuda_version'] = torch.version.cuda
        info['cudnn_version'] = torch.backends.cudnn.version()
        info['gpu_count'] = gpu_count
        info['gpu_names'] = [torch.cuda.get_device_name(i) for i in range(gpu_count)]

    return info


def log_environment_info() -> None:
    """Log every entry of get_environment_info() as ``key: value`` between 80-char '=' rules."""
    rule = "=" * 80
    details = get_environment_info()
    for header_line in (rule, "Environment Information:", rule):
        logger.info(header_line)
    for name, val in details.items():
        logger.info(f"{name}: {val}")
    logger.info(rule)


def setup_reproducibility(seed: int, deterministic: bool = False) -> None:
    """
    Seed all RNGs, optionally switch PyTorch to deterministic mode, then log the environment.

    Args:
        seed: Value forwarded to set_random_seeds.
        deterministic: When truthy, set_deterministic_mode(True) is called.
            When falsy, the current cuDNN/deterministic settings are left
            untouched (deterministic mode is not explicitly disabled).
    """
    set_random_seeds(seed)

    if deterministic:
        set_deterministic_mode(enabled=True)

    log_environment_info()