import argparse
import os
import shutil
import time
from datetime import datetime
from importlib import import_module
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import OmegaConf


class Config:
    def __init__(
        self,
        config_path: Optional[str] = None,
        override_args: Optional[Dict[str, Any]] = None,
    ):
        self.config = OmegaConf.create({})

        # Load main config if provided
        if config_path:
            self.load_yaml(config_path)

        if override_args:
            self.override_config(override_args)

    def load_yaml(self, config_path: str):
        """Load a YAML configuration file and merge it into the current config."""
        loaded_config = OmegaConf.load(config_path)
        self.config = OmegaConf.merge(self.config, loaded_config)

    def override_config(self, override_args: Dict[str, Any]):
        """Apply command-line override arguments to the config."""
        for key, value in override_args.items():
            # OmegaConf.update handles dotted keys (e.g. "trainer.lr") directly,
            # so values are converted once here rather than re-parsed through a
            # dotlist. This keeps path-like strings (e.g. ".../config.yaml")
            # intact instead of subjecting them to a second round of type
            # inference.
            OmegaConf.update(self.config, key, self._convert_value(value))

    def _convert_value(self, value: Any) -> Any:
        """Convert a string value to the most appropriate Python type."""
        if not isinstance(value, str):
            return value
        if value.lower() == "true":
            return True
        elif value.lower() == "false":
            return False
        elif value.lower() == "null":
            return None
        try:
            return int(value)
        except ValueError:
            try:
                return float(value)
            except ValueError:
                return value

    def get(self, key: str, default: Any = None) -> Any:
        """Get a configuration value by dotted key."""
        return OmegaConf.select(self.config, key, default=default)

    def __getattr__(self, name: str) -> Any:
        """Support dot-notation access."""
        return self.config[name]

    def __getitem__(self, key: str) -> Any:
        """Support dictionary-like access."""
        return self.config[key]

    def export_config(self, path: str):
        """Export the current configuration to a file."""
        OmegaConf.save(self.config, path)


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", type=str, required=True, help="Path to config file"
    )
    parser.add_argument(
        "--override", type=str, nargs="+", help="Override config values (key=value)"
    )
    return parser.parse_args()


def load_config(
    config_path: Optional[str] = None,
    override_args: Optional[Dict[str, Any]] = None,
) -> Config:
    """Load a configuration, falling back to CLI arguments when no path is given."""
    if config_path is None:
        args = parse_args()
        config_path = args.config
        if args.override:
            override_args = {}
            for override in args.override:
                key, value = override.split("=", 1)
                override_args[key.strip()] = value.strip()
    return Config(config_path, override_args)


def instantiate(target, cfg=None, hfstyle=False, **init_args):
    """Instantiate a class from its dotted import path."""
    module_name, class_name = target.rsplit(".", 1)
    module = import_module(module_name)
    class_ = getattr(module, class_name)
    if cfg is None:
        return class_(**init_args)
    if hfstyle:
        config_class = class_.config_class
        cfg = config_class(config_obj=cfg)
    return class_(cfg, **init_args)


def get_function(target):
    """Resolve a function from its dotted import path."""
    module_name, function_name = target.rsplit(".", 1)
    module = import_module(module_name)
    function_ = getattr(module, function_name)
    return function_
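# A minimal usage sketch for the config helpers above. The paths, keys, and
# dotted targets here are illustrative placeholders, not part of this module:
#
#   cfg = load_config("configs/train.yaml", {"trainer.max_epochs": "10"})
#   print(cfg.get("trainer.max_epochs"))  # -> 10 (converted to int)
#   model = instantiate("torch.nn.Linear", in_features=8, out_features=2)
#
# When called without arguments, load_config falls back to the CLI, e.g.:
#   python train.py --config configs/train.yaml --override trainer.lr=0.001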
def save_config_and_codes(config, save_dir):
    """Save the resolved config and a snapshot of the source tree for reproducibility."""
    os.makedirs(save_dir, exist_ok=True)
    sanity_check_dir = os.path.join(save_dir, "sanity_check")
    os.makedirs(sanity_check_dir, exist_ok=True)

    with open(os.path.join(sanity_check_dir, f"{config.exp_name}.yaml"), "w") as f:
        OmegaConf.save(config.config, f)

    current_dir = Path.cwd()
    exclude_dir = current_dir / "outputs"
    for py_file in current_dir.rglob("*.py"):
        if exclude_dir in py_file.parents:
            continue
        dest_path = Path(sanity_check_dir) / py_file.relative_to(current_dir)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(py_file, dest_path)


def print_model_size(model):
    """Log total, trainable, and non-trainable parameter counts."""
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    rank_zero_info(f"Total parameters: {total_params:,}")
    rank_zero_info(f"Trainable parameters: {trainable_params:,}")
    rank_zero_info(f"Non-trainable parameters: {(total_params - trainable_params):,}")


def compare_statedict_and_parameters(state_dict, named_parameters, named_buffers):
    """Compare the differences between a state_dict and a model's parameters."""
    # Collect the key sets from the state_dict and from named_parameters
    state_dict_keys = set(state_dict.keys())
    named_params_keys = set(name for name, _ in named_parameters)

    # Keys present on only one side
    only_in_state_dict = state_dict_keys - named_params_keys
    only_in_named_params = named_params_keys - state_dict_keys

    if only_in_state_dict:
        print(f"Only in state_dict (not in parameters): {sorted(only_in_state_dict)}")
    if only_in_named_params:
        print(
            f"Only in named_parameters (not in state_dict): {sorted(only_in_named_params)}"
        )
    if not only_in_state_dict and not only_in_named_params:
        print("All parameters match between state_dict and named_parameters")

    # Additionally account for buffers (non-parameter state, such as BatchNorm's running_mean)
    named_buffers_keys = set(name for name, _ in named_buffers)
    buffers_only = state_dict_keys - named_params_keys - named_buffers_keys
    if buffers_only:
        print(
            f"Other items in state_dict (neither params nor buffers): {sorted(buffers_only)}"
        )

    print(f"Total state_dict items: {len(state_dict_keys)}")
    print(f"Total named_parameters: {len(named_params_keys)}")
    print(f"Total named_buffers: {len(named_buffers_keys)}")


def _resolve_global_rank() -> int:
    """Resolve the global rank from environment variables."""
    for key in ("GLOBAL_RANK", "RANK", "SLURM_PROCID", "LOCAL_RANK"):
        if key in os.environ:
            try:
                return int(os.environ[key])
            except ValueError:
                continue
    return 0
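# Illustrative check with the inspection helpers above: BatchNorm layers carry
# buffers (running_mean, running_var, num_batches_tracked) that appear in the
# state_dict but not in named_parameters.
#
#   model = torch.nn.BatchNorm1d(8)
#   compare_statedict_and_parameters(
#       model.state_dict(), model.named_parameters(), model.named_buffers()
#   )
#   print_model_size(model)  # 16 trainable parameters (weight + bias)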
def get_shared_run_time(base_dir: str, env_key: str = "PL_RUN_TIME") -> str:
    """
    Get a run time synchronized across all processes.

    This ensures that all processes (in both distributed training and
    multi-process launches) use the same timestamp for output directories
    and experiment tracking.

    Args:
        base_dir: Base directory for output files
        env_key: Environment variable key under which the run time is cached

    Returns:
        Synchronized timestamp string in the format YYYYMMDD_HHMMSS
    """
    cached = os.environ.get(env_key)
    if cached:
        return cached

    timestamp_format = "%Y%m%d_%H%M%S"

    # Preferred path: if a torch.distributed process group is initialized,
    # broadcast the timestamp from rank 0 to all other ranks.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            run_time = datetime.now().strftime(timestamp_format)
        else:
            run_time = None
        container = [run_time]
        torch.distributed.broadcast_object_list(container, src=0)
        run_time = container[0]
        if run_time is None:
            raise RuntimeError("Failed to synchronize run time across ranks.")
        os.environ[env_key] = run_time
        return run_time

    # Fallback path: synchronize through a file on shared storage, keyed by
    # the job ID so that concurrent jobs do not interfere with each other.
    os.makedirs(base_dir, exist_ok=True)
    sync_token = (
        os.environ.get("SLURM_JOB_ID")
        or os.environ.get("TORCHELASTIC_RUN_ID")
        or os.environ.get("JOB_ID")
        or "default"
    )
    sync_dir = os.path.join(base_dir, ".run_time_sync")
    os.makedirs(sync_dir, exist_ok=True)
    sync_file = os.path.join(sync_dir, f"{sync_token}.txt")

    global_rank = _resolve_global_rank()
    if global_rank == 0:
        # Remove the sync file if it exists to avoid stale reads by other ranks
        if os.path.exists(sync_file):
            try:
                os.remove(sync_file)
            except OSError:
                pass
        run_time = datetime.now().strftime(timestamp_format)
        with open(sync_file, "w", encoding="utf-8") as f:
            f.write(run_time)
    else:
        timeout = time.monotonic() + 1200.0
        while True:
            if os.path.exists(sync_file):
                try:
                    with open(sync_file, "r", encoding="utf-8") as f:
                        run_time = f.read().strip()
                    # Accept the timestamp only if it is fresh (within 60 seconds).
                    # This prevents reading a stale timestamp from a previous run.
                    dt = datetime.strptime(run_time, timestamp_format)
                    if abs((datetime.now() - dt).total_seconds()) < 60:
                        break
                except (ValueError, OSError):
                    # The file might be empty, partially written, or malformed
                    pass
            if time.monotonic() > timeout:
                raise TimeoutError(
                    "Timed out waiting for rank 0 to write synchronized timestamp."
                )
            time.sleep(0.1)

    os.environ[env_key] = run_time
    return run_time
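if __name__ == "__main__":
    # Smoke test: in a single, non-distributed process this exercises the
    # file-based fallback. The "outputs" directory name is an illustrative
    # choice for base_dir, not something this module requires.
    run_time = get_shared_run_time("outputs")
    print(f"Synchronized run time: {run_time}")
    # A second call returns the value cached in the PL_RUN_TIME env var.
    assert get_shared_run_time("outputs") == run_time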