import argparse
import os
import shutil
import time
from datetime import datetime
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import OmegaConf


class Config:
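    """Thin wrapper around an OmegaConf config.

    Optionally loads a YAML file and applies dotted-key overrides
    (e.g. "trainer.max_epochs") on top of it.
    """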
    def __init__(
        self,
        config_path: Optional[str] = None,
        override_args: Optional[Dict[str, Any]] = None,
    ):
        self.config = OmegaConf.create({})

        # Load main config if provided
        if config_path:
            self.load_yaml(config_path)
        if override_args:
            self.override_config(override_args)

    def load_yaml(self, config_path: str):
        """Load YAML configuration file"""
        loaded_config = OmegaConf.load(config_path)
        self.config = OmegaConf.merge(self.config, loaded_config)

    def override_config(self, override_args: Dict[str, Any]):
        """Apply command-line overrides onto the current config.

        Each value arrives as a string and is converted to a basic Python
        type before being written via its dotted key. OmegaConf.update
        handles nested keys such as "trainer.max_epochs"; values that do not
        parse as bool/int/float/null (e.g. paths ending in ".yaml") are kept
        as strings.
        """
        for key, value in override_args.items():
            OmegaConf.update(self.config, key, self._convert_value(value))

    def _convert_value(self, value: str) -> Any:
        """Convert string value to appropriate type"""
        if value.lower() == "true":
            return True
        elif value.lower() == "false":
            return False
        elif value.lower() == "null":
            return None
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value

    def get(self, key: str, default: Any = None) -> Any:
        """Get configuration value"""
        return OmegaConf.select(self.config, key, default=default)

    def __getattr__(self, name: str) -> Any:
        """Support dot notation access"""
        try:
            return self.config[name]
        except KeyError as e:
            raise AttributeError(name) from e

    def __getitem__(self, key: str) -> Any:
        """Support dictionary-like access"""
        return self.config[key]

    def export_config(self, path: str):
        """Export current configuration to file"""
        OmegaConf.save(self.config, path)


def parse_args():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to config file"
    )
    parser.add_argument(
        "--override", type=str, nargs="+", help="Override config values (key=value)"
    )
    return parser.parse_args()


def load_config(
    config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
) -> Config:
    """Load configuration from a YAML file plus optional dotted-key overrides.

    If config_path is None, the path and overrides are read from the command
    line via parse_args().
    """
    if config_path is None:
        args = parse_args()
        config_path = args.config
        if args.override:
            override_args = {}
            for override in args.override:
                key, value = override.split("=", 1)
                override_args[key.strip()] = value.strip()

    return Config(config_path, override_args)
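
# Example (illustrative, not executed): load a YAML config and override a
# nested value from code. The config path below is hypothetical.
#
#   cfg = load_config("configs/train.yaml", {"trainer.max_epochs": "10"})
#   cfg.get("trainer.max_epochs")  # -> 10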


def instantiate(target, cfg=None, hfstyle=False, **init_args):
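    """Import and instantiate a class from its dotted path.

    If cfg is given it is passed as the first positional argument; with
    hfstyle=True the target class is expected to expose a config_class that
    wraps cfg first. Remaining keyword arguments go to the constructor.

    Illustrative example (class path and kwargs are hypothetical):
        model = instantiate("my_pkg.models.MyModel", cfg=model_cfg, hidden_size=256)
    """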
    module_name, class_name = target.rsplit(".", 1)
    module = import_module(module_name)
    class_ = getattr(module, class_name)
    if cfg is None:
        return class_(**init_args)
    else:
        if hfstyle:
            config_class = class_.config_class
            cfg = config_class(config_obj=cfg)
        return class_(cfg, **init_args)


def get_function(target):
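    """Import and return a function (or any module attribute) from its dotted path."""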
    module_name, function_name = target.rsplit(".", 1)
    module = import_module(module_name)
    function_ = getattr(module, function_name)
    return function_


def save_config_and_codes(config, save_dir):
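    """Snapshot the resolved config and source code for reproducibility.

    Writes the config as <exp_name>.yaml and copies every *.py file under the
    current working directory (excluding ./outputs) into <save_dir>/sanity_check.
    """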
    os.makedirs(save_dir, exist_ok=True)
    sanity_check_dir = os.path.join(save_dir, "sanity_check")
    os.makedirs(sanity_check_dir, exist_ok=True)
    with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
        OmegaConf.save(config.config, f)
    current_dir = Path.cwd()
    exclude_dir = current_dir / "outputs"
    for py_file in current_dir.rglob("*.py"):
        if exclude_dir in py_file.parents:
            continue
        dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(py_file, dest_path)


def print_model_size(model):
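    """Log total, trainable, and non-trainable parameter counts via rank_zero_info."""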
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    rank_zero_info(f"Total parameters: {total_params:,}")
    rank_zero_info(f"Trainable parameters: {trainable_params:,}")
    rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")


def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
    """Compare differences between state_dict and parameters"""
    # Get all keys in state_dict
    state_dict_keys = set(state_dict.keys())

    # Get all keys in named_parameters
    named_params_keys = set(name for name, _ in named_parameters)

    # Find keys that only exist in state_dict
    only_in_state_dict = state_dict_keys - named_params_keys

    # Find keys that only exist in named_parameters
    only_in_named_params = named_params_keys - state_dict_keys

    # Print results
    if only_in_state_dict:
        print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")

    if only_in_named_params:
        print(
            f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
        )

    if not only_in_state_dict and not only_in_named_params:
        print("All parameters match between state_dict and named_parameters")

    # Additionally account for buffers (non-parameter state, such as BatchNorm's running_mean)
    named_buffers_keys = set(name for name, _ in named_buffers)
    neither_params_nor_buffers = state_dict_keys - named_params_keys - named_buffers_keys

    if neither_params_nor_buffers:
        print(
            f"Other items in state_dict (neither params nor buffers): {sorted(neither_params_nor_buffers)}"
        )

    print(f"Total state_dict items: {len(state_dict_keys)}")
    print(f"Total named_parameters: {len(named_params_keys)}")
    print(f"Total named_buffers: {len(named_buffers_keys)}")


def _resolve_global_rank() -> int:
    """Resolve the global rank from environment variables."""
    for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
        if key in os.environ:
            try:
                return int(os.environ[key])
            except ValueError:
                continue
    return 0


def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
    """
    Get a synchronized run time across all processes.

    This function ensures all processes (both in distributed training and multi-process
    scenarios) use the same timestamp for output directories and experiment tracking.

    Args:
        base_dir: Base directory for output files
        env_key: Environment variable key to cache the run time

    Returns:
        Synchronized timestamp string in format YYYYMMDD_HHMMSS
    """
    cached = os.environ.get(env_key)
    if cached:
        return cached

    timestamp_format = "%Y%m%d_%H%M%S"

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            run_time = datetime.now().strftime(timestamp_format)
        else:
            run_time = None
        container = [run_time]
        torch.distributed.broadcast_object_list(container, src=0)
        run_time = container[0]
        if run_time is None:
            raise RuntimeError("Failed to synchronize run time across ranks.")
        os.environ[env_key] = run_time
        return run_time

    os.makedirs(base_dir, exist_ok=True)
    sync_token = (
        os.environ.get("SLURM_JOB_ID")
        or os.environ.get("TORCHELASTIC_RUN_ID")
        or os.environ.get("JOB_ID")
        or "default"
    )
    sync_dir = os.path.join(base_dir, ".run_time_sync")
    os.makedirs(sync_dir, exist_ok=True)
    sync_file = os.path.join(sync_dir, f"{sync_token}.txt")

    global_rank = _resolve_global_rank()
    if global_rank == 0:
        # Remove the sync file if it exists to avoid stale reads by other ranks
        if os.path.exists(sync_file):
            try:
                os.remove(sync_file)
            except OSError:
                pass

        run_time = datetime.now().strftime(timestamp_format)
        with open(sync_file, "w", encoding="utf-8") as f:
            f.write(run_time)
    else:
        timeout = time.monotonic() + 1200.0
        while True:
            if os.path.exists(sync_file):
                try:
                    with open(sync_file, "r", encoding="utf-8") as f:
                        run_time = f.read().strip()
                    # Check if the timestamp is fresh (within 60 seconds)
                    # This prevents reading a stale timestamp from a previous run
                    dt = datetime.strptime(run_time, timestamp_format)
                    if abs((datetime.now() - dt).total_seconds()) < 60:
                        break
                except (ValueError, OSError):
                    # File might be empty or partially written, or format mismatch
                    pass

            if time.monotonic() > timeout:
                raise TimeoutError(
                    "Timed out waiting for rank 0 to write synchronized timestamp."
                )
            time.sleep(0.1)

    os.environ[env_key] = run_time
    return run_time