File size: 7,919 Bytes
9ed2e4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
#!/usr/bin/env python3
"""
Repository model inspector.
This script is designed to work in `--config-only` mode without importing
PyTorch/Diffusers/Transformers. It reads JSON configs from a local Diffusers
repository layout and prints a summary.
With `--params`, it can also compute parameter counts by scanning
`*.safetensors` headers (without loading tensor data into RAM).
"""
import argparse
import json
import math
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
def load_json(path: Path) -> Dict[str, Any]:
    """Read the UTF-8 encoded file at *path* and return its parsed JSON content."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
def human_params(value: Optional[int]) -> str:
    """Format a parameter count as a short human-readable string.

    Returns ``"n/a"`` when *value* is ``None``. Counts of at least one
    billion are rendered in billions with a ``B`` suffix; every other count
    is rendered in millions with an ``M`` suffix. Two decimals are always
    shown.
    """
    if value is None:
        return "n/a"
    divisor, suffix = (1e9, "B") if value >= 1_000_000_000 else (1e6, "M")
    return f"{value / divisor:.2f}{suffix}"
def read_model_index(model_dir: Path) -> Dict[str, Any]:
    """Return the parsed ``model_index.json`` found directly under *model_dir*.

    An empty dict is returned when the file does not exist.
    """
    index_file = model_dir / "model_index.json"
    return load_json(index_file) if index_file.exists() else {}
def describe_model_index(model_index: Dict[str, Any]) -> None:
    """Print one line per pipeline component listed in *model_index*.

    Keys that start with an underscore are skipped. Nothing at all is
    printed when *model_index* is empty.
    """
    if not model_index:
        return

    print("Pipeline pieces (model_index.json):")
    visible_entries = (
        (name, spec) for name, spec in model_index.items() if not name.startswith("_")
    )
    for name, spec in visible_entries:
        print(f" {name:14s} -> {spec}")
    print()
def detect_pipeline_kind(model_index: Dict[str, Any]) -> str:
    """Classify a pipeline as ``"zimage"``, ``"sdxl_like"`` or ``"unknown"``.

    The decision uses the lower-cased ``_class_name`` entry together with
    the presence of the ``transformer`` and ``unet`` keys. The ``zimage``
    rule is checked first, so it wins when both rules would match.
    """
    class_name = str(model_index.get("_class_name", "")).lower()
    has_transformer = "transformer" in model_index
    has_unet = "unet" in model_index

    if "zimage" in class_name or (has_transformer and not has_unet):
        return "zimage"
    if "stable" in class_name or has_unet:
        return "sdxl_like"
    return "unknown"
def iter_safetensors_files(directory: Path) -> Iterable[Path]:
    """Return the ``*.safetensors`` files directly inside *directory*, sorted.

    Only regular files whose suffix is exactly ``.safetensors`` are
    returned; sub-directories are not descended into. An empty list is
    returned when *directory* is missing or is not a directory.
    """
    # is_dir() rather than exists(): a regular file at this path also
    # "exists", and calling iterdir() on it would raise NotADirectoryError.
    if not directory.is_dir():
        return []
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix == ".safetensors"
    )
def _read_safetensors_header(file: Path) -> Dict[str, Any]:
    """Parse and return the JSON header of a single ``.safetensors`` file.

    The file starts with an 8-byte little-endian unsigned integer giving the
    header length, followed by that many bytes of UTF-8 JSON. Only these
    bytes are read; the tensor data that follows is never touched.

    Raises:
        ValueError: if the file is too short or the declared header length
            exceeds the file size.
    """
    with open(file, "rb") as handle:
        file_size = handle.seek(0, 2)  # whence=2: jump to the end to learn the size
        handle.seek(0)
        prefix = handle.read(8)
        header_len = int.from_bytes(prefix, "little")
        # Validate before reading so a corrupt length cannot trigger a huge allocation.
        if len(prefix) != 8 or header_len > file_size - 8:
            raise ValueError(f"{file}: not a valid safetensors file (bad header length)")
        return json.loads(handle.read(header_len))


def count_params_from_safetensors(files: Iterable[Path]) -> int:
    """Return the total number of tensor elements stored in *files*.

    Shapes are taken from each file's JSON header using only the standard
    library, so neither the ``safetensors`` package nor any tensor data
    needs to be loaded. A tensor with an empty shape counts as one element.

    Args:
        files: paths of ``.safetensors`` files to scan.

    Returns:
        The sum over all tensors of the product of their dimensions; ``0``
        when *files* is empty.

    Raises:
        ValueError: if a file does not start with a valid header.
    """
    total = 0
    for file in files:
        header = _read_safetensors_header(file)
        for tensor_name, entry in header.items():
            # "__metadata__" holds free-form string metadata, not a tensor.
            if tensor_name == "__metadata__":
                continue
            total += math.prod(entry["shape"])
    return int(total)
def zimage_config_only_summary(model_dir: Path, include_params: bool) -> Dict[str, Any]:
    """Print a per-component summary of the layout under *model_dir* and return it.

    The JSON configs of the text encoder, transformer, VAE and scheduler are
    read from their sub-directories; a missing config file is treated as an
    empty config and reported with a ``[warn]`` line. When *include_params*
    is true, parameter counts for the text encoder, transformer and VAE are
    computed from the ``*.safetensors`` files in the respective
    sub-directory; otherwise the counts stay ``None``.

    Returns:
        A dict with the keys ``kind``, ``pipeline``, ``text_encoder``,
        ``transformer``, ``vae`` and ``scheduler`` holding the loaded configs
        and, where applicable, the parameter counts.
    """

    def _optional_config(*relative: str) -> Dict[str, Any]:
        # A config file that is absent is represented by an empty dict.
        cfg_path = model_dir.joinpath(*relative)
        return load_json(cfg_path) if cfg_path.exists() else {}

    def _fields(cfg: Dict[str, Any], *pairs: Iterable[str]) -> str:
        # Each pair is (printed label, config key); absent keys show as "n/a".
        return ", ".join(f"{label}={cfg.get(key, 'n/a')}" for label, key in pairs)

    pipeline_index = read_model_index(model_dir)

    text_cfg = _optional_config("text_encoder", "config.json")
    dit_cfg = _optional_config("transformer", "config.json")
    autoencoder_cfg = _optional_config("vae", "config.json")
    sched_cfg = _optional_config("scheduler", "scheduler_config.json")

    # Counting happens before any section is printed, component by component.
    counts: Dict[str, Optional[int]] = {"text_encoder": None, "transformer": None, "vae": None}
    if include_params:
        for component in counts:
            counts[component] = count_params_from_safetensors(
                iter_safetensors_files(model_dir / component)
            )

    print("[Text encoder]")
    if not text_cfg:
        print(" [warn] missing text_encoder/config.json")
    else:
        architectures = text_cfg.get("architectures", [])
        first_arch = "n/a"
        if isinstance(architectures, list) and architectures:
            first_arch = architectures[0]
        print(f" architecture={first_arch}")
        print(
            " "
            + _fields(
                text_cfg,
                ("layers", "num_hidden_layers"),
                ("hidden", "hidden_size"),
                ("heads", "num_attention_heads"),
                ("intermediate", "intermediate_size"),
            )
        )
        print(
            " "
            + _fields(
                text_cfg,
                ("vocab", "vocab_size"),
                ("max_positions", "max_position_embeddings"),
            )
        )
    print(f" params={human_params(counts['text_encoder'])}")
    print()

    print("[Transformer]")
    if not dit_cfg:
        print(" [warn] missing transformer/config.json")
    else:
        print(" " + _fields(dit_cfg, ("class", "_class_name")))
        print(
            " "
            + _fields(dit_cfg, ("dim", "dim"), ("layers", "n_layers"), ("heads", "n_heads"))
        )
        print(
            " "
            + _fields(dit_cfg, ("in_channels", "in_channels"), ("cap_feat_dim", "cap_feat_dim"))
        )
        print(
            " "
            + _fields(
                dit_cfg,
                ("patch_size", "all_patch_size"),
                ("f_patch_size", "all_f_patch_size"),
            )
        )
    print(f" params={human_params(counts['transformer'])}")
    print()

    print("[VAE]")
    if not autoencoder_cfg:
        print(" [warn] missing vae/config.json")
    else:
        print(" " + _fields(autoencoder_cfg, ("class", "_class_name")))
        print(
            " "
            + _fields(
                autoencoder_cfg,
                ("sample_size", "sample_size"),
                ("in_channels", "in_channels"),
                ("latent_channels", "latent_channels"),
                ("out_channels", "out_channels"),
            )
        )
        print(
            " "
            + _fields(
                autoencoder_cfg,
                ("block_out_channels", "block_out_channels"),
                ("scaling_factor", "scaling_factor"),
            )
        )
    print(f" params={human_params(counts['vae'])}")
    print()

    print("[Scheduler]")
    if not sched_cfg:
        print(" [warn] missing scheduler/scheduler_config.json")
    else:
        print(
            " "
            + _fields(
                sched_cfg,
                ("class", "_class_name"),
                ("timesteps", "num_train_timesteps"),
                ("shift", "shift"),
            )
        )
    print()

    return {
        "kind": "zimage",
        "pipeline": pipeline_index,
        "text_encoder": {"config": text_cfg, "params": counts["text_encoder"]},
        "transformer": {"config": dit_cfg, "params": counts["transformer"]},
        "vae": {"config": autoencoder_cfg, "params": counts["vae"]},
        "scheduler": {"config": sched_cfg},
    }
def main() -> None:
    """Command-line entry point: parse arguments, print the summary, optionally save it.

    Exits via ``SystemExit`` when ``model_index.json`` yields no content,
    when ``--config-only`` was not passed, or when the detected pipeline
    kind is not ``zimage``. With ``--json-out`` the summary is also written
    to the given path as indented JSON, creating parent directories as
    needed.
    """
    cli = argparse.ArgumentParser(description="Inspect a local Diffusers-style repository layout.")
    option_table = (
        ("--model-dir", {"type": Path, "default": Path(".."), "help": "Path to the diffusers pipeline directory."}),
        ("--device", {"default": "cpu", "help": "Unused (kept for CLI compatibility)."}),
        ("--fp16", {"action": "store_true", "help": "Unused (kept for CLI compatibility)."}),
        ("--config-only", {"action": "store_true", "help": "Read JSON configs and print a summary."}),
        ("--params", {"action": "store_true", "help": "Count parameters from *.safetensors headers (no tensor loading)."}),
        ("--json-out", {"type": Path, "default": None, "help": "Write a JSON summary to this path."}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    pipeline_index = read_model_index(opts.model_dir)
    if not pipeline_index:
        raise SystemExit(f"model_index.json not found under {opts.model_dir}")
    describe_model_index(pipeline_index)

    # The component listing above is printed before the mode and kind checks run.
    pipeline_kind = detect_pipeline_kind(pipeline_index)
    if not opts.config_only:
        raise SystemExit("Only --config-only mode is supported by this inspector.")
    if pipeline_kind != "zimage":
        raise SystemExit(f"Unsupported pipeline kind: {pipeline_kind} (expected ZImagePipeline-style layout)")

    report = zimage_config_only_summary(opts.model_dir, include_params=opts.params)

    destination = opts.json_out
    if destination is None:
        return
    destination.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, indent=2, ensure_ascii=False) + "\n"
    destination.write_text(serialized, encoding="utf-8")
    print(f"[info] wrote JSON summary to {destination}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|