import argparse
import os
import shutil
import time
from datetime import datetime
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import OmegaConf
class Config:
    """Thin wrapper around an OmegaConf config with YAML loading and dotted-key overrides."""
    def __init__(
        self, config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
    ):
self.config = OmegaConf.create({})
# Load main config if provided
if config_path:
self.load_yaml(config_path)
if override_args:
self.override_config(override_args)
def load_yaml(self, config_path: str):
"""Load YAML configuration file"""
loaded_config = OmegaConf.load(config_path)
self.config = OmegaConf.merge(self.config, loaded_config)
    def override_config(self, override_args: Dict[str, Any]):
        """Apply command line overrides given as dotted keys mapped to raw string values."""
        for key, value in override_args.items():
            # Coerce the raw string to bool/None/int/float where possible; anything
            # else (e.g. a path ending in ".yaml") is kept as a plain string.
            val = self._convert_value(value)
            # OmegaConf.update understands dotted keys and creates intermediate
            # nodes, so nested values can be overridden in place.
            OmegaConf.update(self.config, key, val)
def _convert_value(self, value: str) -> Any:
"""Convert string value to appropriate type"""
if value.lower() == "true":
return True
elif value.lower() == "false":
return False
elif value.lower() == "null":
return None
try:
return int(value)
except ValueError:
try:
return float(value)
except ValueError:
return value
def get(self, key: str, default: Any = None) -> Any:
"""Get configuration value"""
return OmegaConf.select(self.config, key, default=default)
    def __getattr__(self, name: str) -> Any:
        """Support dot notation access"""
        try:
            return self.config[name]
        except KeyError as err:
            raise AttributeError(name) from err  # so hasattr()/getattr(default) behave as expected
def __getitem__(self, key: str) -> Any:
"""Support dictionary-like access"""
return self.config[key]
def export_config(self, path: str):
"""Export current configuration to file"""
OmegaConf.save(self.config, path)
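# Illustrative sketch (not part of the public API; the key names below are invented):
# Config can be built without a YAML file by passing only overrides; dotted keys
# create nested nodes and _convert_value coerces the raw strings.
def _example_config_overrides() -> None:  # documentation-only helper, never called at import
    cfg = Config(config_path=None, override_args={"model.lr": "0.001", "model.name": "resnet"})
    assert cfg.get("model.lr") == 0.001      # "0.001" coerced to float
    assert cfg["model"]["name"] == "resnet"  # non-numeric strings are kept as-is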
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, required=True, help="Path to config file"
)
parser.add_argument(
"--override", type=str, nargs="+", help="Override config values (key=value)"
)
return parser.parse_args()
def load_config(
config_path: Optional[str] = None, override_args: Optional[Dict[str, Any]] = None
) -> Config:
"""Load configuration"""
if config_path is None:
args = parse_args()
config_path = args.config
if args.override:
override_args = {}
for override in args.override:
key, value = override.split("=", 1)
override_args[key.strip()] = value.strip()
return Config(config_path, override_args)
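# Hedged sketch of how --override tokens are parsed (mirrors load_config above;
# the token below is invented): splitting on the first "=" only keeps values that
# themselves contain "=" or path suffixes such as ".yaml" intact.
def _example_override_parsing() -> None:  # documentation-only helper
    token = "data.path=/data/train=v2.yaml"
    key, value = token.split("=", 1)
    assert key == "data.path" and value == "/data/train=v2.yaml"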
def instantiate(target, cfg=None, hfstyle=False, **init_args):
    """Instantiate a class from its fully qualified dotted path; if ``cfg`` is given it is
    passed as the first positional argument, wrapped in ``class_.config_class`` when
    ``hfstyle`` is True (HuggingFace-style construction)."""
    module_name, class_name = target.rsplit(".", 1)
    module = import_module(module_name)
    class_ = getattr(module, class_name)
    if cfg is None:
        return class_(**init_args)
    if hfstyle:
        config_class = class_.config_class
        cfg = config_class(config_obj=cfg)
    return class_(cfg, **init_args)
def get_function(target):
    """Resolve a function (or any attribute) from its fully qualified dotted path."""
    module_name, function_name = target.rsplit(".", 1)
    module = import_module(module_name)
    function_ = getattr(module, function_name)
    return function_
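# Minimal sketch of instantiate/get_function; the dotted targets below are
# standard-library/torch examples, not ones used by this project.
def _example_dynamic_import() -> None:  # documentation-only helper
    layer = instantiate("torch.nn.Linear", in_features=8, out_features=2)
    assert isinstance(layer, torch.nn.Linear)
    sqrt = get_function("math.sqrt")
    assert sqrt(9.0) == 3.0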
def save_config_and_codes(config, save_dir):
    """Dump the resolved config and copy every project .py file (excluding outputs/) into
    a sanity_check folder for later reproduction."""
os.makedirs(save_dir, exist_ok=True)
sanity_check_dir = os.path.join(save_dir, "sanity_check")
os.makedirs(sanity_check_dir, exist_ok=True)
with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
OmegaConf.save(config.config, f)
current_dir = Path.cwd()
exclude_dir = current_dir / "outputs"
for py_file in current_dir.rglob("*.py"):
if exclude_dir in py_file.parents:
continue
dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(py_file, dest_path)
def print_model_size(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
rank_zero_info(f"Total parameters: {total_params:,}")
rank_zero_info(f"Trainable parameters: {trainable_params:,}")
rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")
def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
"""Compare differences between state_dict and parameters"""
# Get all keys in state_dict
state_dict_keys = set(state_dict.keys())
# Get all keys in named_parameters
named_params_keys = set(name for name, _ in named_parameters)
# Find keys that only exist in state_dict
only_in_state_dict = state_dict_keys - named_params_keys
# Find keys that only exist in named_parameters
only_in_named_params = named_params_keys - state_dict_keys
# Print results
if only_in_state_dict:
print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")
if only_in_named_params:
print(
f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
)
if not only_in_state_dict and not only_in_named_params:
print("All parameters match between state_dict and named_parameters")
# Additionally compare buffers (non-parameter states, such as BatchNorm's running_mean)
named_buffers_keys = set(name for name, _ in named_buffers)
buffers_only = state_dict_keys - named_params_keys - named_buffers_keys
if buffers_only:
print(
f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
)
print(f"Total state_dict items: {len(state_dict_keys)}")
print(f"Total named_parameters: {len(named_params_keys)}")
print(f"Total named_buffers: {len(named_buffers_keys)}")
def _resolve_global_rank() -> int:
"""Resolve the global rank from environment variables."""
for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
if key in os.environ:
try:
return int(os.environ[key])
except ValueError:
continue
return 0
def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
"""
Get a synchronized run time across all processes.
This function ensures all processes (both in distributed training and multi-process
scenarios) use the same timestamp for output directories and experiment tracking.
Args:
base_dir: Base directory for output files
env_key: Environment variable key to cache the run time
Returns:
Synchronized timestamp string in format YYYYMMDD_HHMMSS
"""
cached = os.environ.get(env_key)
if cached:
return cached
timestamp_format = "%Y%m%d_%H%M%S"
if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
run_time = datetime.now().strftime(timestamp_format)
else:
run_time = None
container = [run_time]
torch.distributed.broadcast_object_list(container, src=0)
run_time = container[0]
if run_time is None:
raise RuntimeError("Failed to synchronize run time across ranks.")
os.environ[env_key] = run_time
return run_time
os.makedirs(base_dir, exist_ok=True)
sync_token = (
os.environ.get("SLURM_JOB_ID")
or os.environ.get("TORCHELASTIC_RUN_ID")
or os.environ.get("JOB_ID")
or "default"
)
sync_dir = os.path.join(base_dir, ".run_time_sync")
os.makedirs(sync_dir, exist_ok=True)
sync_file = os.path.join(sync_dir, f"{sync_token}.txt")
global_rank = _resolve_global_rank()
if global_rank == 0:
# Remove the sync file if it exists to avoid stale reads by other ranks
if os.path.exists(sync_file):
try:
os.remove(sync_file)
except OSError:
pass
run_time = datetime.now().strftime(timestamp_format)
with open(sync_file, "w", encoding="utf-8") as f:
f.write(run_time)
else:
timeout = time.monotonic() + 1200.0
while True:
if os.path.exists(sync_file):
try:
with open(sync_file, "r", encoding="utf-8") as f:
run_time = f.read().strip()
# Check if the timestamp is fresh (within 60 seconds)
# This prevents reading a stale timestamp from a previous run
dt = datetime.strptime(run_time, timestamp_format)
if abs((datetime.now() - dt).total_seconds()) < 60:
break
except (ValueError, OSError):
# File might be empty or partially written, or format mismatch
pass
if time.monotonic() > timeout:
raise TimeoutError(
"Timed out waiting for rank 0 to write synchronized timestamp."
)
time.sleep(0.1)
os.environ[env_key] = run_time
return run_time
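# Hedged usage sketch: within one process the timestamp is cached in the PL_RUN_TIME
# environment variable, so repeated calls agree (the base_dir below is made up, and
# the call writes a .run_time_sync marker file under it).
def _example_shared_run_time(base_dir: str = "outputs/example") -> None:  # documentation-only helper
    first = get_shared_run_time(base_dir)
    second = get_shared_run_time(base_dir)
    assert first == second  # served from the cached env var after the first call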