iljung1106 committed on
Commit 07f1b5a · 1 Parent(s): 89e6d19

add ddp py

Files changed (1)
  1. scripts/train_style_ddp.py +1268 -0
scripts/train_style_ddp.py ADDED
@@ -0,0 +1,1268 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os, re, math, random, glob, time, subprocess, sys, zlib, gc, warnings, atexit
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from datetime import datetime
8
+ from typing import Optional, Dict, List
9
+
10
+ import numpy as np
11
+ from PIL import Image, ImageFile
12
+ from PIL.Image import DecompressionBombWarning
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.multiprocessing as mp
18
+ import torch.distributed as dist
19
+
20
+ from torch.utils.data import Dataset, DataLoader, Sampler
21
+ from torchvision import transforms
22
+
23
+ # tqdm (auto-install if missing)
24
+ try:
25
+ from tqdm.auto import tqdm
26
+ except Exception:
27
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "tqdm"])
28
+ from tqdm.auto import tqdm
29
+
30
+ # ------------------------- Config -------------------------
31
+ @dataclass
32
+ class Cfg:
33
+ data_root: str = "./"
34
+ folders: dict = None
35
+ stages: list = None
36
+ P: int = 16
37
+ K: int = 2
38
+ embed_dim: int = 256
39
+ workers: int = 8
40
+ weight_decay: float = 0.01
41
+ alpha_proxy: float = 32.0
42
+ margin_proxy: float = 0.2
43
+ supcon_tau: float = 0.07
44
+ mv_tau: float = 0.10
45
+ mixstyle_p: float = 0.10
46
+ out_dir: str = "./checkpoints_style"
47
+ seed: int = 1337
48
+ max_steps_per_epoch: Optional[int] = None # if None, derived automatically from the dataset length
49
+ print_every: int = 50
50
+ use_compile: bool = False
51
+
52
+ cfg = Cfg(
53
+ folders=dict(whole="dataset", face="dataset_face", eyes="dataset_eyes"),
54
+ stages=[
55
+ dict(sz_whole=224, sz_face=192, sz_eyes=128, epochs=12, lr=3e-4, P=64, K=2),
56
+ dict(sz_whole=384, sz_face=320, sz_eyes=192, epochs=12, lr=1.5e-4, P=24, K=2),
57
+ dict(sz_whole=512, sz_face=384, sz_eyes=224, epochs=24, lr=8e-5, P=12, K=2),
58
+ ],
59
+ )
60
+
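+ # Note: the stages train at progressively larger crops while P (classes per
+ # batch) shrinks, presumably to keep the local batch P*K within GPU memory:
+ # 64*2 crops at 224px, 24*2 at 384px, 12*2 at 512px.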
61
+ # ------------------------- Device & determinism -------------------------
62
+ def seed_all(seed: int):
63
+ random.seed(seed)
64
+ np.random.seed(seed)
65
+ torch.manual_seed(seed)
66
+ if torch.cuda.is_available():
67
+ torch.cuda.manual_seed_all(seed)
68
+
69
+ seed_all(cfg.seed)
70
+
71
+ torch.backends.cuda.matmul.allow_tf32 = True
72
+ torch.backends.cudnn.allow_tf32 = True
73
+ torch.backends.cudnn.benchmark = True
74
+ if hasattr(torch, "set_float32_matmul_precision"):
75
+ torch.set_float32_matmul_precision("high")
76
+
77
+ if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
78
+ amp_dtype = torch.bfloat16
79
+ else:
80
+ amp_dtype = torch.float16
81
+
82
+ # --- PIL safety/verbosity tweaks ---
83
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
84
+ Image.MAX_IMAGE_PIXELS = 300_000_000
85
+ warnings.filterwarnings("ignore", category=DecompressionBombWarning)
86
+ warnings.filterwarnings("ignore", category=UserWarning, module="PIL.TiffImagePlugin")
87
+
88
+ # ------------------------- Robust multiprocessing for DataLoader -------------------------
89
+ def _init_mp_ctx():
90
+ method = mp.get_start_method(allow_none=True)
91
+ if method is None:
92
+ preferred = 'fork' if sys.platform.startswith('linux') else 'spawn'
93
+ try:
94
+ mp.set_start_method(preferred, force=True)
95
+ except Exception:
96
+ pass
97
+ method = mp.get_start_method(allow_none=True) or preferred
98
+ print(f"[mp] using '{method}'.")
99
+ return mp.get_context(method)
100
+
101
+ MP_CTX = _init_mp_ctx()
102
+
103
+ _DL_TRACK = []
104
+ def _track_dl(dl):
105
+ _DL_TRACK.append(dl); return dl
106
+
107
+ def _close_dl(dl):
108
+ try:
109
+ it = getattr(dl, "_iterator", None)
110
+ if it is not None:
111
+ it._shutdown_workers()
112
+ dl._iterator = None
113
+ except Exception:
114
+ pass
115
+
116
+ @atexit.register
117
+ def _cleanup_all_dls():
118
+ for dl in list(_DL_TRACK):
119
+ _close_dl(dl)
120
+ _DL_TRACK.clear()
121
+
122
+ def _should_fallback_workers(err: Exception) -> bool:
123
+ s = str(err)
124
+ return ("Can't get attribute" in s or
125
+ "PicklingError" in s or
126
+ ("AttributeError" in s and "__main__" in s))
127
+
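+ # If a DataLoader worker dies with a pickling / "Can't get attribute" error
+ # (typical when the 'spawn' start method cannot re-import objects defined in
+ # __main__), the training loop below rebuilds its loader with num_workers=0
+ # instead of crashing.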
128
+ # ------------------------- Helpers -------------------------
129
+ def stable_int(s: str) -> int:
130
+ return zlib.adler32(s.encode("utf-8")) & 0xffffffff
131
+
132
+ def l2n(x, eps=1e-8):
133
+ return F.normalize(x, dim=-1, eps=eps)
134
+
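+ # stable_int gives a deterministic, process-independent hash per path (unlike
+ # Python's per-process salted hash()); the dataset below uses
+ # stable_int(path) % 10 to assign buckets 0-8 to train and bucket 9 to val.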
135
+ # ------------------------- Dataset -------------------------
136
+ class TriViewDataset(Dataset):
137
+ """
138
+ - whole / face / eyes each get a 9:1 train/val split (based on a path hash).
139
+ - __getitem__ samples randomly from the artist's view pools to build the tri-view.
140
+ - No filename matching at all; any images are combined as long as the artist (label) matches.
141
+ - The index is built from the whole view; label/gid/path all refer to the whole image.
142
+ """
143
+
144
+ def __init__(self, root, folders, split="train",
145
+ T_whole=None, T_face=None, T_eyes=None):
146
+ assert split in ("train", "val")
147
+ self.split = split
148
+ self.root = Path(root)
149
+ self.dirs = {k: self.root / v for k, v in folders.items()}
150
+ self.T = dict(whole=T_whole, face=T_face, eyes=T_eyes)
151
+
152
+ # artist list
153
+ whole_root = self.dirs["whole"]
154
+ artists = sorted([d.name for d in whole_root.iterdir() if d.is_dir()])
155
+ self.artist2id = {a: i for i, a in enumerate(artists)}
156
+ self.id2artist = {v: k for k, v in self.artist2id.items()}
157
+ self.num_classes = len(self.artist2id)
158
+
159
+ # per-artist view pools (one per split)
160
+ self.whole_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()}
161
+ self.face_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()}
162
+ self.eyes_paths_by_artist: Dict[int, List[Path]] = {aid: [] for aid in self.id2artist.keys()}
163
+
164
+ def view_split(paths: List[Path], split: str) -> List[Path]:
165
+ train_list, val_list = [], []
166
+ for p in paths:
167
+ h = stable_int(str(p)) % 10
168
+ if split == "train":
169
+ if h < 9: # 0~8 => train
170
+ train_list.append(p)
171
+ else:
172
+ if h >= 9: # 9 => val
173
+ val_list.append(p)
174
+ return train_list if split == "train" else val_list
175
+
176
+ # per-artist split for each of whole / face / eyes
177
+ for artist_name, aid in self.artist2id.items():
178
+ # whole
179
+ w_dir = self.dirs["whole"] / artist_name
180
+ if w_dir.is_dir():
181
+ w_all = sorted([p for p in w_dir.iterdir() if p.is_file()])
182
+ else:
183
+ w_all = []
184
+ self.whole_paths_by_artist[aid] = view_split(w_all, split)
185
+
186
+ # face
187
+ f_dir = self.dirs["face"] / artist_name
188
+ if f_dir.is_dir():
189
+ f_all = sorted([p for p in f_dir.iterdir() if p.is_file()])
190
+ else:
191
+ f_all = []
192
+ self.face_paths_by_artist[aid] = view_split(f_all, split)
193
+
194
+ # eyes
195
+ e_dir = self.dirs["eyes"] / artist_name
196
+ if e_dir.is_dir():
197
+ e_all = sorted([p for p in e_dir.iterdir() if p.is_file()])
198
+ else:
199
+ e_all = []
200
+ self.eyes_paths_by_artist[aid] = view_split(e_all, split)
201
+
202
+ # index: anchors are whole-view images
203
+ self.index = []
204
+ for aid, w_list in self.whole_paths_by_artist.items():
205
+ for wp in w_list:
206
+ rec = {
207
+ "label": aid,
208
+ "whole": str(wp),
209
+ "gid": stable_int(str(wp)),
210
+ "path": str(wp),
211
+ }
212
+ self.index.append(rec)
213
+
214
+ def __len__(self):
215
+ return len(self.index)
216
+
217
+ def _load_one(self, path: Optional[Path], T):
218
+ if path is None:
219
+ return None
220
+ try:
221
+ im = Image.open(path).convert("RGB")
222
+ except Exception:
223
+ return None
224
+ if T is not None:
225
+ return T(im)
226
+ else:
227
+ return transforms.ToTensor()(im)
228
+
229
+ def __getitem__(self, i):
230
+ rec = self.index[i]
231
+ aid = rec["label"]
232
+
233
+ W_pool = self.whole_paths_by_artist.get(aid, [])
234
+ F_pool = self.face_paths_by_artist.get(aid, [])
235
+ E_pool = self.eyes_paths_by_artist.get(aid, [])
236
+
237
+ pw = random.choice(W_pool) if W_pool else None
238
+ pf = random.choice(F_pool) if F_pool else None
239
+ pe = random.choice(E_pool) if E_pool else None
240
+
241
+ xw = self._load_one(pw, self.T["whole"]) if pw is not None else None
242
+ xf = self._load_one(pf, self.T["face"]) if pf is not None else None
243
+ xe = self._load_one(pe, self.T["eyes"]) if pe is not None else None
244
+
245
+ gid = torch.tensor([rec["gid"]], dtype=torch.long)
246
+ return dict(
247
+ whole=xw,
248
+ face=xf,
249
+ eyes=xe,
250
+ label=torch.tensor(aid, dtype=torch.long),
251
+ gid=gid,
252
+ path=rec["path"],
253
+ )
254
+
255
+ # ------------------------- PK batch sampler -------------------------
256
+ class PKBatchSampler(Sampler):
257
+ """P개 클래스 × K개 이미지를 한 배치로 뽑는 샘플러."""
258
+ def __init__(self, dataset: TriViewDataset, P: int, K: int):
259
+ self.P, self.K = int(P), int(K)
260
+ from collections import defaultdict
261
+ self.by_cls = defaultdict(list)
262
+ for idx, rec in enumerate(dataset.index):
263
+ self.by_cls[rec["label"]].append(idx)
264
+ self.labels = list(self.by_cls.keys())
265
+ for lst in self.by_cls.values():
266
+ random.shuffle(lst)
267
+
268
+ def __iter__(self):
269
+ while True:
270
+ P, K = self.P, self.K
271
+ if len(self.labels) >= P:
272
+ classes = random.sample(self.labels, P)
273
+ else:
274
+ classes = random.choices(self.labels, k=P)
275
+ batch = []
276
+ for c in classes:
277
+ pool = self.by_cls[c]
278
+ if len(pool) >= K:
279
+ picks = random.sample(pool, K)
280
+ else:
281
+ picks = [random.choice(pool) for _ in range(K)]
282
+ batch.extend(picks)
283
+ yield batch
284
+
285
+ def __len__(self): # not used
286
+ return 10**9
287
+
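+ # Each PK batch thus holds P distinct artists with K images apiece (e.g.
+ # P=16, K=2 -> 32 samples), so every sample has at least one same-class
+ # positive for the SupCon / multi-view InfoNCE terms used later.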
288
+ # ------------------------- Collate & transforms -------------------------
289
+ def collate_triview(batch):
290
+ labels = torch.stack([b["label"] for b in batch])
291
+ gids = torch.stack([b["gid"] for b in batch]).squeeze(1)
292
+ paths = [b["path"] for b in batch]
293
+ views, masks = {}, {}
294
+ for k in ("whole", "face", "eyes"):
295
+ xs = [b[k] for b in batch]
296
+ mask = torch.tensor([x is not None for x in xs], dtype=torch.bool)
297
+ if any(mask):
298
+ ex = next(x for x in xs if x is not None)
299
+ zeros = torch.zeros_like(ex)
300
+ xs = [x if x is not None else zeros for x in xs]
301
+ views[k] = torch.stack(xs, dim=0)
302
+ else:
303
+ views[k] = None
304
+ masks[k] = mask
305
+ return dict(views=views, masks=masks, labels=labels, gids=gids, paths=paths)
306
+
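+ # Missing views are padded with zero tensors so the batch can still be
+ # stacked; the per-view boolean masks let the model drop those entries from
+ # the view-gate softmax and the per-view losses.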
307
+ def make_transforms(sz_w, sz_f, sz_e):
308
+ def aug(s):
309
+ return transforms.Compose([
310
+ transforms.RandomResizedCrop(s, scale=(0.6, 1.0)),
311
+ transforms.RandomHorizontalFlip(),
312
+ transforms.ColorJitter(brightness=0.1, contrast=0.1,
313
+ saturation=0.05, hue=0.02),
314
+ transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),
315
+ transforms.ToTensor(),
316
+ transforms.Normalize([0.5]*3, [0.5]*3),
317
+ ])
318
+ return aug(sz_w), aug(sz_f), aug(sz_e)
319
+
320
+ def make_val_transforms(sz_w, sz_f, sz_e):
321
+ def val(s):
322
+ return transforms.Compose([
323
+ transforms.Resize(int(s*1.15)),
324
+ transforms.CenterCrop(s),
325
+ transforms.ToTensor(),
326
+ transforms.Normalize([0.5]*3, [0.5]*3),
327
+ ])
328
+ return val(sz_w), val(sz_f), val(sz_e)
329
+
330
+ # ------------------------- Model & heads -------------------------
331
+ class MixStyle(nn.Module):
332
+ def __init__(self, p=0.3, alpha=0.1):
333
+ super().__init__()
334
+ self.p = p; self.alpha = alpha
335
+ def forward(self, x):
336
+ if not self.training or self.p <= 0.0:
337
+ return x
338
+ B,C,H,W = x.shape
339
+ mu = x.mean([2,3], keepdim=True)
340
+ var = x.var([2,3], unbiased=False, keepdim=True)
341
+ sigma = (var+1e-5).sqrt()
342
+ perm = torch.randperm(B, device=x.device)
343
+ mu2, sigma2 = mu[perm], sigma[perm]
344
+ lam = torch.distributions.Beta(self.alpha, self.alpha).sample((B,1,1,1)).to(x.device)
345
+ mu_mix = mu*lam + mu2*(1-lam)
346
+ sigma_mix = sigma*lam + sigma2*(1-lam)
347
+ x_norm = (x - mu)/sigma
348
+ apply = (torch.rand(B,1,1,1, device=x.device) < self.p).float()
349
+ mixed = x_norm * sigma_mix + mu_mix
350
+ return mixed*apply + x*(1-apply)
351
+
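+ # MixStyle-style regularization: with probability p, each sample's channel
+ # mean/std are blended with those of a randomly permuted sample using
+ # lam ~ Beta(alpha, alpha), which discourages the encoder from keying on the
+ # low-level statistics of any single source image.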
352
+ class SqueezeExcite(nn.Module):
353
+ def __init__(self, c, r=16):
354
+ super().__init__()
355
+ m = max(8, c//r)
356
+ self.net = nn.Sequential(
357
+ nn.AdaptiveAvgPool2d(1),
358
+ nn.Conv2d(c, m, 1), nn.GELU(),
359
+ nn.Conv2d(m, c, 1), nn.Sigmoid()
360
+ )
361
+ def forward(self, x):
362
+ return x * self.net(x)
363
+
364
+ class ConvBlock(nn.Module):
365
+ def __init__(self, ci, co, k=3, s=1, p=1):
366
+ super().__init__()
367
+ self.conv = nn.Conv2d(ci, co, k, s, p, bias=False)
368
+ self.gn = nn.GroupNorm(16, co)
369
+ self.act = nn.GELU()
370
+ def forward(self, x):
371
+ return self.act(self.gn(self.conv(x)))
372
+
373
+ class ResBlock(nn.Module):
374
+ def __init__(self, ci, co, down=False, mix=None):
375
+ super().__init__()
376
+ self.c1 = ConvBlock(ci, co, 3, 1, 1)
377
+ self.c2 = ConvBlock(co, co, 3, 1, 1)
378
+ self.se = SqueezeExcite(co)
379
+ self.down = down
380
+ self.pool = nn.AvgPool2d(2) if down else nn.Identity()
381
+ self.proj = nn.Conv2d(ci, co, 1, 1, 0, bias=False) if ci != co else nn.Identity()
382
+ self.mix = mix
383
+ def forward(self, x):
384
+ h = self.c1(x)
385
+ if self.mix is not None:
386
+ h = self.mix(h)
387
+ h = self.c2(h)
388
+ h = self.se(h)
389
+ if self.down:
390
+ h = self.pool(h); x = self.pool(x)
391
+ return F.gelu(h + self.proj(x))
392
+
393
+ def matrix_sqrt_newton_schulz(A, iters=5):
394
+ B,C,_ = A.shape
395
+ normA = A.reshape(B, -1).norm(dim=1).view(B,1,1).clamp(min=1e-8)
396
+ Y = A / normA
397
+ I = torch.eye(C, device=A.device).expand(B, C, C)
398
+ Z = I.clone()
399
+ for _ in range(iters):
400
+ T = 0.5 * (3.0*I - Z.bmm(Y))
401
+ Y = Y.bmm(T)
402
+ Z = T.bmm(Z)
403
+ return Y * (normA.sqrt())
404
+
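+ # Newton-Schulz iteration: a batched, differentiable approximation of the
+ # matrix square root that avoids an SVD on the GPU; the input is pre-scaled
+ # by its Frobenius norm so the iteration stays in its convergence region,
+ # and the scale is restored at the end.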
405
+ class GramHead(nn.Module):
406
+ def __init__(self, c_in, c_red=64, proj=128):
407
+ super().__init__()
408
+ self.red = nn.Conv2d(c_in, c_red, 1, bias=False)
409
+ self.proj = nn.Linear(c_red*c_red, proj)
410
+ def forward(self, x):
411
+ f = self.red(x)
412
+ B,C,H,W = f.shape
413
+ Fm = f.flatten(2)
414
+ G = torch.bmm(Fm, Fm.transpose(1,2)) / (H*W)
415
+ return self.proj(G.reshape(B, C*C))
416
+
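+ # GramHead: 1x1 reduction to c_red channels, then the (H*W)-normalized Gram
+ # matrix of the feature map (the classic second-order style descriptor),
+ # flattened and projected to a small vector.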
417
+ class CovISqrtHead(nn.Module):
418
+ def __init__(self, c_in, c_red=64, proj=128):
419
+ super().__init__()
420
+ self.red = nn.Conv2d(c_in, c_red, 1, bias=False)
421
+ self.proj = nn.Linear(c_red*c_red, proj)
422
+ def forward(self, x):
423
+ with torch.amp.autocast('cuda', enabled=False):
424
+ f = self.red(x.float())
425
+ B,C,H,W = f.shape
426
+ Fm = f.flatten(2)
427
+ mu = Fm.mean(-1, keepdim=True)
428
+ Xc = Fm - mu
429
+ cov = torch.bmm(Xc, Xc.transpose(1,2)) / (H*W - 1 + 1e-5)
430
+ cov = matrix_sqrt_newton_schulz(cov.float(), iters=5)
431
+ return self.proj(cov.reshape(B, C*C))
432
+
433
+ def spectrum_hist(x, K=16, O=8):
434
+ B,C,H,W = x.shape
435
+ spec = torch.fft.rfft2(x, norm='ortho').abs().mean(1)
436
+ H2, W2 = spec.shape[-2], spec.shape[-1]
437
+ yy, xx = torch.meshgrid(
438
+ torch.linspace(-1, 1, H2, device=x.device),
439
+ torch.linspace(0, 1, W2, device=x.device),
440
+ indexing="ij"
441
+ )
442
+ rr = (yy**2 + xx**2).sqrt().clamp(0, 1 - 1e-8)
443
+ th = (torch.atan2(yy, xx + 1e-9) + math.pi/2)
444
+ rb = (rr * K).long().clamp(0, K-1)
445
+ ob = (th / math.pi * O).long().clamp(0, O-1)
446
+ mag = torch.log1p(spec)
447
+ rad = torch.zeros(B, K, device=x.device)
448
+ ang = torch.zeros(B, O, device=x.device)
449
+ rbf = rb.reshape(-1); obf = ob.reshape(-1)
450
+ for b in range(B):
451
+ m = mag[b].reshape(-1)
452
+ rad[b].scatter_add_(0, rbf, m)
453
+ ang[b].scatter_add_(0, obf, m)
454
+ rad = rad / (rad.sum(-1, keepdim=True)+1e-6)
455
+ ang = ang / (ang.sum(-1, keepdim=True)+1e-6)
456
+ return torch.cat([rad, ang], dim=1)
457
+
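+ # spectrum_hist bins the log-magnitude of the 2D FFT into K radial and O
+ # angular histograms (each normalized to sum to 1), a compact frequency /
+ # orientation signature of texture; SpectrumHead disables autocast and casts
+ # to fp32 before calling it.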
458
+ class SpectrumHead(nn.Module):
459
+ def __init__(self, c_in, proj=64, K=16, O=8):
460
+ super().__init__()
461
+ self.proj = nn.Linear(K+O, proj)
462
+ def forward(self, x):
463
+ with torch.amp.autocast('cuda', enabled=False):
464
+ h = spectrum_hist(x.float())
465
+ return self.proj(h)
466
+
467
+ class StatsHead(nn.Module):
468
+ def __init__(self, c_in, proj=64):
469
+ super().__init__()
470
+ c = min(64, c_in)
471
+ self.red = nn.Conv2d(c_in, c, 1, bias=False)
472
+ self.mlp = nn.Sequential(
473
+ nn.Linear(c*2, 128),
474
+ nn.GELU(),
475
+ nn.Linear(128, proj),
476
+ )
477
+ def forward(self, x):
478
+ f = self.red(x)
479
+ mu = f.mean([2,3])
480
+ lv = torch.log(f.var([2,3], unbiased=False)+1e-5)
481
+ return self.mlp(torch.cat([mu, lv], dim=1))
482
+
483
+ class ViewEncoder(nn.Module):
484
+ """
485
+ - RGB input normalized with Normalize([0.5],[0.5])
486
+ - RGB -> Lab conversion
487
+ - backbone + 4 style heads (Gram/Cov/Spectrum/Stats)
488
+ - branch attention
489
+ """
490
+ def __init__(self, mix_p=0.3, out_dim=256):
491
+ super().__init__()
492
+ self.mix = MixStyle(p=mix_p, alpha=0.1)
493
+ ch = [32, 64, 128, 192, 256]
494
+
495
+ self.stem = nn.Sequential(
496
+ ConvBlock(3, ch[0], 3, 1, 1),
497
+ ConvBlock(ch[0], ch[0], 3, 1, 1),
498
+ )
499
+ self.b1 = ResBlock(ch[0], ch[1], down=True, mix=self.mix)
500
+ self.b2 = ResBlock(ch[1], ch[2], down=True, mix=self.mix)
501
+ self.b3 = ResBlock(ch[2], ch[3], down=True, mix=None)
502
+ self.b4 = ResBlock(ch[3], ch[4], down=True, mix=None)
503
+
504
+ self.h_gram3 = GramHead(ch[3])
505
+ self.h_cov3 = CovISqrtHead(ch[3])
506
+ self.h_sp3 = SpectrumHead(ch[3])
507
+ self.h_st3 = StatsHead(ch[3])
508
+
509
+ self.h_gram4 = GramHead(ch[4])
510
+ self.h_cov4 = CovISqrtHead(ch[4])
511
+ self.h_sp4 = SpectrumHead(ch[4])
512
+ self.h_st4 = StatsHead(ch[4])
513
+
514
+ fdim = (128+128+64+64)*2 # 768
515
+ self.fdim = fdim
516
+
517
+ self.branch_gate = nn.Sequential(
518
+ nn.LayerNorm(fdim),
519
+ nn.Linear(fdim, 4, bias=True),
520
+ )
521
+
522
+ self.fuse = nn.Sequential(
523
+ nn.Linear(fdim, 512),
524
+ nn.GELU(),
525
+ nn.Linear(512, out_dim),
526
+ )
527
+
528
+ def _rgb_to_lab(self, x: torch.Tensor) -> torch.Tensor:
529
+ with torch.amp.autocast('cuda', enabled=False):
530
+ x_f = x.float()
531
+ rgb = (x_f * 0.5 + 0.5).clamp(0.0, 1.0)
532
+
533
+ thresh = 0.04045
534
+ low = rgb / 12.92
535
+ high = ((rgb + 0.055) / 1.055).pow(2.4)
536
+ rgb_lin = torch.where(rgb <= thresh, low, high)
537
+
538
+ rgb_lin = rgb_lin.permute(0, 2, 3, 1)
539
+ M = rgb_lin.new_tensor([
540
+ [0.4124564, 0.3575761, 0.1804375],
541
+ [0.2126729, 0.7151522, 0.0721750],
542
+ [0.0193339, 0.1191920, 0.9503041],
543
+ ])
544
+ xyz = torch.matmul(rgb_lin, M.T)
545
+
546
+ Xn, Yn, Zn = 0.95047, 1.00000, 1.08883
547
+ xyz = xyz / rgb_lin.new_tensor([Xn, Yn, Zn])
548
+
549
+ eps = 0.008856
550
+ kappa = 903.3
551
+
552
+ def f(t):
553
+ t = t.clamp(min=1e-6)
554
+ return torch.where(
555
+ t > eps,
556
+ t.pow(1.0 / 3.0),
557
+ (kappa * t + 16.0) / 116.0,
558
+ )
559
+
560
+ f_xyz = f(xyz)
561
+ fx, fy, fz = f_xyz[..., 0], f_xyz[..., 1], f_xyz[..., 2]
562
+
563
+ L = 116.0 * fy - 16.0
564
+ a = 500.0 * (fx - fy)
565
+ b = 200.0 * (fy - fz)
566
+
567
+ L_scaled = L / 100.0
568
+ a_scaled = (a + 128.0) / 255.0
569
+ b_scaled = (b + 128.0) / 255.0
570
+
571
+ lab = torch.stack([L_scaled, a_scaled, b_scaled], dim=-1)
572
+ lab = lab.permute(0, 3, 1, 2)
573
+
574
+ return lab.to(dtype=x.dtype)
575
+
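+ # _rgb_to_lab: undoes the [-1,1] normalization, applies the standard
+ # sRGB -> linear RGB -> XYZ (D65 white point) -> CIELAB conversion in fp32,
+ # then rescales L to [0,1] and a/b to roughly [0,1] so downstream style
+ # statistics live in a more perceptually uniform colour space.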
576
+ def forward(self, x):
577
+ x_lab = self._rgb_to_lab(x)
578
+
579
+ f0 = self.stem(x_lab)
580
+ f1 = self.b1(f0)
581
+ f2 = self.b2(f1)
582
+ f3 = self.b3(f2)
583
+ f4 = self.b4(f3)
584
+
585
+ g3 = self.h_gram3(f3)
586
+ c3 = self.h_cov3(f3)
587
+ sp3 = self.h_sp3(f3)
588
+ st3 = self.h_st3(f3)
589
+
590
+ g4 = self.h_gram4(f4)
591
+ c4 = self.h_cov4(f4)
592
+ sp4 = self.h_sp4(f4)
593
+ st4 = self.h_st4(f4)
594
+
595
+ b_gram = torch.cat([g3, g4], dim=1)
596
+ b_cov = torch.cat([c3, c4], dim=1)
597
+ b_sp = torch.cat([sp3, sp4], dim=1)
598
+ b_st = torch.cat([st3, st4], dim=1)
599
+
600
+ flat = torch.cat([b_gram, b_cov, b_sp, b_st], dim=1) # [B,768]
601
+
602
+ gate_logits = self.branch_gate(flat)
603
+ w = torch.softmax(gate_logits, dim=-1)
604
+ w0, w1, w2, w3 = w[:,0:1], w[:,1:2], w[:,2:3], w[:,3:4]
605
+
606
+ flat_weighted = torch.cat([
607
+ b_gram * w0,
608
+ b_cov * w1,
609
+ b_sp * w2,
610
+ b_st * w3,
611
+ ], dim=1)
612
+
613
+ view_vec = self.fuse(flat_weighted)
614
+ return view_vec
615
+
616
+ class TriViewStyleNet(nn.Module):
617
+ def __init__(self, out_dim=256, mix_p=0.3, share_backbone: bool = True):
618
+ super().__init__()
619
+ if share_backbone:
620
+ shared = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
621
+ self.enc_whole = shared
622
+ self.enc_face = shared
623
+ self.enc_eyes = shared
624
+ else:
625
+ self.enc_whole = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
626
+ self.enc_face = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
627
+ self.enc_eyes = ViewEncoder(mix_p=mix_p, out_dim=out_dim)
628
+ self.view_gate = nn.Sequential(
629
+ nn.LayerNorm(out_dim),
630
+ nn.Linear(out_dim, 1, bias=True),
631
+ )
632
+ def forward(self, views, masks):
633
+ outs, alphas = {}, []
634
+ for k, enc in (("whole", self.enc_whole),
635
+ ("face", self.enc_face),
636
+ ("eyes", self.enc_eyes)):
637
+ if views[k] is None:
638
+ outs[k] = None
639
+ alphas.append(None)
640
+ continue
641
+ vk = enc(views[k].to(memory_format=torch.channels_last))
642
+ outs[k] = l2n(vk)
643
+ score = self.view_gate(outs[k]).squeeze(1)
644
+ score = torch.where(
645
+ masks[k].to(score.device),
646
+ score,
647
+ torch.full_like(score, -1e4),
648
+ )
649
+ alphas.append(score)
650
+ scores = [a for a in alphas if a is not None]
651
+ if len(scores) == 0:
652
+ raise RuntimeError("All views are missing in this batch.")
653
+ A = torch.stack(scores, dim=1) # [B, num_views]
654
+ W = F.softmax(A, dim=1)
655
+ present = [outs[k] for k in ("whole","face","eyes") if outs[k] is not None]
656
+ Z = torch.stack(present, dim=1) # [B, num_views, dim]
657
+ fused = l2n((W.unsqueeze(-1) * Z).sum(dim=1)) # [B, dim]
658
+ return fused, outs, W
659
+
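+ # TriViewStyleNet fuses the whole/face/eyes embeddings (optionally from a
+ # shared backbone) with a learned gate: missing views receive a -1e4 gate
+ # score, so their softmax weight is effectively zero, and the weighted sum
+ # is re-normalized to unit length.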
660
+ # ------------------------- Losses -------------------------
661
+ class ProxyAnchorLoss(nn.Module):
662
+ def __init__(self, num_classes, dim, alpha=16.0, margin=0.1, neg_weight=0.25):
663
+ super().__init__()
664
+ self.proxies = nn.Parameter(torch.randn(num_classes, dim))
665
+ nn.init.normal_(self.proxies, std=0.01)
666
+ self.alpha = float(alpha)
667
+ self.margin = float(margin)
668
+ self.neg_weight = float(neg_weight)
669
+ def forward(self, z, y):
670
+ with torch.amp.autocast('cuda', enabled=False):
671
+ z = F.normalize(z.float(), dim=-1)
672
+ P = F.normalize(self.proxies.float(), dim=-1)
673
+ sim = z @ P.t()
674
+ C = sim.size(1)
675
+ yOH = F.one_hot(y, num_classes=C).float()
676
+ pos_e = torch.clamp(-self.alpha * (sim - self.margin),
677
+ min=-60.0, max=60.0)
678
+ neg_e = torch.clamp( self.alpha * (sim + self.margin),
679
+ min=-60.0, max=60.0)
680
+ pos_term = torch.exp(pos_e) * yOH
681
+ neg_term = torch.exp(neg_e) * (1.0 - yOH)
682
+ pos_sum = pos_term.sum(0)
683
+ neg_sum = neg_term.sum(0)
684
+ num_pos = (yOH.sum(0) > 0)
685
+ L_pos = torch.log1p(pos_sum[num_pos]).sum() / (num_pos.sum().clamp_min(1.0))
686
+ L_neg = torch.log1p(neg_sum).sum() / C
687
+ return L_pos + self.neg_weight * L_neg
688
+
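+ # Proxy-Anchor-style loss: per proxy with positives,
+ #   log(1 + sum_i exp(-alpha * (s_i - margin)))   averaged over such proxies,
+ # plus neg_weight times
+ #   log(1 + sum_j exp( alpha * (s_j + margin)))   averaged over all proxies,
+ # with exponents clamped to [-60, 60] to keep fp32 exp() finite.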
689
+ class SupConLoss(nn.Module):
690
+ def __init__(self, tau=0.07):
691
+ super().__init__()
692
+ self.tau = tau
693
+ def forward(self, feats, labels):
694
+ feats = l2n(feats)
695
+ sim = feats @ feats.t() / self.tau
696
+ logits = sim - torch.eye(sim.size(0), device=sim.device) * 1e9
697
+ pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)) & \
698
+ (~torch.eye(len(labels), device=labels.device, dtype=torch.bool))
699
+ numer = (torch.exp(logits) * pos_mask).sum(1)
700
+ denom = torch.exp(logits).sum(1).clamp_min(1e-8)
701
+ valid = (pos_mask.sum(1) > 0)
702
+ loss = -torch.log((numer+1e-12) / denom)
703
+ return (loss[valid].mean() if valid.any() else torch.tensor(0.0, device=feats.device))
704
+
705
+ class MultiViewInfoNCE(nn.Module):
706
+ def __init__(self, tau=0.1):
707
+ super().__init__()
708
+ self.tau = tau
709
+ def forward(self, feats, gids):
710
+ feats = l2n(feats)
711
+ sim = feats @ feats.t() / self.tau
712
+ logits = sim - torch.eye(sim.size(0), device=sim.device) * 1e9
713
+ pos_mask = (gids.unsqueeze(1) == gids.unsqueeze(0)) & \
714
+ (~torch.eye(len(gids), device=gids.device, dtype=torch.bool))
715
+ numer = (torch.exp(logits) * pos_mask).sum(1)
716
+ denom = torch.exp(logits).sum(1).clamp_min(1e-8)
717
+ valid = (pos_mask.sum(1) > 0)
718
+ loss = -torch.log((numer+1e-12) / denom)
719
+ return (loss[valid].mean() if valid.any() else torch.tensor(0.0, device=feats.device))
720
+
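+ # Both contrastive terms are in-batch InfoNCE: SupConLoss treats samples with
+ # the same artist label as positives, while MultiViewInfoNCE treats the
+ # whole/face/eyes embeddings returned for the same dataset item (same gid)
+ # as positives.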
721
+ # --------------------- Logging / checkpoints / schedulers -----------------
722
+ os.makedirs(cfg.out_dir, exist_ok=True)
723
+ LOG_TXT = os.path.join(cfg.out_dir, "train.log")
724
+ METRICS_CSV = os.path.join(cfg.out_dir, "metrics_epoch.csv")
725
+ if not os.path.exists(METRICS_CSV):
726
+ with open(METRICS_CSV, "w", encoding="utf-8") as f:
727
+ f.write("timestamp,stage,epoch,steps,P,K,train_loss,train_proxy,train_sup,train_mv,"
728
+ "val_proxy,proxy_top1,knn_r1,knn_r5,kmeans_acc,nmi,ari\n")
729
+
730
+ def wlog_global(msg, also_print=False):
731
+ ts_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
732
+ line = f"[{ts_str}] {msg}"
733
+ with open(LOG_TXT, "a", encoding="utf-8", buffering=1) as _logf:
734
+ _logf.write(line + "\n")
735
+ if also_print:
736
+ tqdm.write(line)
737
+
738
+ def write_epoch_metrics(stage_i, epoch_i, steps, P, K,
739
+ tr_mean, tr_p, tr_s, tr_m,
740
+ val_proxy, proxy_top1,
741
+ knn_r1, knn_r5,
742
+ kmeans_acc, nmi, ari):
743
+ ts_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
744
+ def fmt(x):
745
+ if x is None:
746
+ return "nan"
747
+ try:
748
+ if hasattr(x, "item"):
749
+ x = float(x.item())
750
+ else:
751
+ x = float(x)
752
+ except Exception:
753
+ return "nan"
754
+ if np.isnan(x) or np.isinf(x):
755
+ return "nan"
756
+ return f"{x:.6f}"
757
+ with open(METRICS_CSV, "a", encoding="utf-8") as fh:
758
+ fh.write(
759
+ f"{ts_str},{stage_i},{epoch_i},{steps},{P},{K},"
760
+ f"{fmt(tr_mean)},{fmt(tr_p)},{fmt(tr_s)},{fmt(tr_m)},"
761
+ f"{fmt(val_proxy)},{fmt(proxy_top1)},"
762
+ f"{fmt(knn_r1)},{fmt(knn_r5)},"
763
+ f"{fmt(kmeans_acc)},{fmt(nmi)},{fmt(ari)}\n"
764
+ )
765
+
766
+ def save_ckpt(path, model, proxy_loss, optim, sched, meta, is_main: bool):
767
+ if not is_main:
768
+ return
769
+ base_model = model.module if isinstance(model, nn.parallel.DistributedDataParallel) else model
770
+ torch.save({
771
+ "model": base_model.state_dict(),
772
+ "proxies": proxy_loss.state_dict(),
773
+ "optim": optim.state_dict() if optim else None,
774
+ "sched": sched.state_dict() if sched else None,
775
+ "meta": meta,
776
+ }, path)
777
+
778
+ def find_latest_checkpoint(out_dir):
779
+ paths = glob.glob(os.path.join(out_dir, "stage*_epoch*.pt"))
780
+ best, best_stage, best_epoch = None, -1, -1
781
+ for p in paths:
782
+ m = re.search(r"stage(\d+)_epoch(\d+)\.pt$", os.path.basename(p))
783
+ if not m:
784
+ continue
785
+ si, ep = int(m.group(1)), int(m.group(2))
786
+ if (si > best_stage) or (si == best_stage and ep > best_epoch):
787
+ best, best_stage, best_epoch = p, si, ep
788
+ return best, best_stage, best_epoch
789
+
790
+ def _pick_from_schedule(sched, default_val, ep):
791
+ if not sched:
792
+ return int(default_val)
793
+ if isinstance(sched, dict):
794
+ items = sorted([(int(k), int(v)) for k,v in sched.items()], key=lambda x: x[0])
795
+ else:
796
+ items = sorted([(int(k), int(v)) for k,v in sched], key=lambda x: x[0])
797
+ val = int(default_val)
798
+ for k,v in items:
799
+ if ep >= k:
800
+ val = int(v)
801
+ return int(val)
802
+
803
+ def resolve_epoch_PK(stage: dict, ep: int):
804
+ P = int(stage.get("P", cfg.P))
805
+ K = int(stage.get("K", cfg.K))
806
+ P = _pick_from_schedule(stage.get("P_schedule"), P, ep)
807
+ K = _pick_from_schedule(stage.get("K_schedule"), K, ep)
808
+ bs_sched = stage.get("bs_schedule")
809
+ if bs_sched:
810
+ bs = _pick_from_schedule(bs_sched, P*K, ep)
811
+ if bs % K != 0:
812
+ wlog_global(f"[batch] bs_schedule value {bs} not divisible by K={K}; rounding down to {bs//K*K}", also_print=True)
813
+ bs = (bs // K) * K
814
+ P = max(1, bs // K)
815
+ return int(P), int(K)
816
+
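+ # resolve_epoch_PK lets a stage carry optional P_schedule / K_schedule /
+ # bs_schedule mappings keyed by epoch (e.g. a hypothetical {1: 64, 8: 32});
+ # the entry with the largest key <= the current epoch wins, and bs_schedule
+ # values are rounded down to a multiple of K.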
817
+ def estimate_steps_per_epoch(train_len: int, global_batch: int, max_steps: Optional[int]):
818
+ if max_steps is not None:
819
+ return int(max_steps)
820
+ return max(1, math.ceil(train_len / max(1, global_batch)))
821
+
822
+ def build_train_loader(ds: TriViewDataset, P: int, K: int):
823
+ bs = PKBatchSampler(ds, P, K)
824
+ dl = DataLoader(
825
+ ds,
826
+ batch_sampler=bs,
827
+ num_workers=cfg.workers,
828
+ pin_memory=True,
829
+ collate_fn=collate_triview,
830
+ persistent_workers=False,
831
+ prefetch_factor=2 if cfg.workers > 0 else None,
832
+ multiprocessing_context=MP_CTX,
833
+ )
834
+ return _track_dl(dl)
835
+
836
+ def make_cosine_with_warmup(optimizer, warmup_steps, total_steps):
837
+ def lr_lambda(step):
838
+ if step < warmup_steps:
839
+ return float(step) / max(1, warmup_steps)
840
+ rem = max(1, total_steps - warmup_steps)
841
+ progress = (step - warmup_steps) / rem
842
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
843
+ return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
844
+
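+ # Per-step LR schedule: linear warmup for `warmup_steps`, then cosine decay
+ # to zero over the remaining steps of the stage; sched.step() is called after
+ # every optimizer step, not once per epoch.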
845
+ # ------------------------------ DDP worker --------------------------------
846
+ def ddp_train_worker(rank: int, world_size: int):
847
+ torch.cuda.set_device(rank)
848
+ device = torch.device("cuda", rank)
849
+
850
+ os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
851
+ os.environ.setdefault("MASTER_PORT", "29500")
852
+ os.environ["RANK"] = str(rank)
853
+ os.environ["WORLD_SIZE"] = str(world_size)
854
+
855
+ dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
856
+
857
+ seed_all(cfg.seed + rank)
858
+ is_main = (rank == 0)
859
+
860
+ # class count
861
+ artists_dir = os.path.join(cfg.data_root, cfg.folders['whole'])
862
+ num_classes_total = len([
863
+ d for d in os.listdir(artists_dir)
864
+ if os.path.isdir(os.path.join(artists_dir, d))
865
+ ])
866
+ if is_main:
867
+ wlog_global(f"[DDP] world_size={world_size}, num_classes_total={num_classes_total}", also_print=True)
868
+
869
+ # model & losses
870
+ base_model = TriViewStyleNet(
871
+ out_dim=cfg.embed_dim,
872
+ mix_p=cfg.mixstyle_p,
873
+ share_backbone=True,
874
+ ).to(device)
875
+ base_model = base_model.to(memory_format=torch.channels_last)
876
+
877
+ if cfg.use_compile and hasattr(torch, "compile"):
878
+ try:
879
+ base_model = torch.compile(base_model, mode="reduce-overhead", fullgraph=False)
880
+ except Exception:
881
+ pass
882
+
883
+ model = nn.parallel.DistributedDataParallel(
884
+ base_model,
885
+ device_ids=[rank],
886
+ output_device=rank,
887
+ find_unused_parameters=False,
888
+ )
889
+
890
+ proxy_loss = ProxyAnchorLoss(
891
+ num_classes=num_classes_total,
892
+ dim=cfg.embed_dim,
893
+ alpha=cfg.alpha_proxy,
894
+ margin=cfg.margin_proxy,
895
+ neg_weight=0.25,
896
+ ).to(device)
897
+
898
+ supcon = SupConLoss(tau=cfg.supcon_tau).to(device)
899
+ mv_infonce = MultiViewInfoNCE(tau=cfg.mv_tau).to(device)
900
+
901
+ # resume
902
+ resume_info = None
903
+ ckpt_path, ck_stage, ck_epoch = find_latest_checkpoint(cfg.out_dir)
904
+ if ckpt_path is not None:
905
+ ck = torch.load(ckpt_path, map_location="cpu")
906
+ try:
907
+ model.module.load_state_dict(ck["model"], strict=False)
908
+ except Exception as e:
909
+ if is_main:
910
+ wlog_global(f"[resume] WARNING: model state load failed: {e}", also_print=True)
911
+ try:
912
+ proxy_loss.load_state_dict(ck["proxies"])
913
+ except Exception as e:
914
+ if is_main:
915
+ wlog_global(f"[resume] WARNING: proxy state load failed: {e}", also_print=True)
916
+
917
+ meta = ck.get("meta", {})
918
+ last_stage = int(meta.get("stage", ck_stage or 1))
919
+ last_epoch = int(meta.get("epoch", ck_epoch or 0))
920
+ start_stage = last_stage
921
+ start_epoch = last_epoch + 1
922
+ if start_stage <= len(cfg.stages) and start_epoch > cfg.stages[start_stage-1]["epochs"]:
923
+ start_stage += 1
924
+ start_epoch = 1
925
+
926
+ resume_info = dict(
927
+ ckpt=ck,
928
+ path=ckpt_path,
929
+ last_stage=last_stage,
930
+ last_epoch=last_epoch,
931
+ start_stage=start_stage,
932
+ start_epoch=start_epoch,
933
+ )
934
+ if is_main:
935
+ wlog_global(
936
+ f"[resume] Found {ckpt_path} (stage {last_stage}, epoch {last_epoch}). "
937
+ f"Resuming at stage {start_stage}, epoch {start_epoch}.",
938
+ also_print=True,
939
+ )
940
+ else:
941
+ if is_main:
942
+ wlog_global("[resume] No checkpoint found; training from scratch.", also_print=True)
943
+
944
+ scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())
945
+ global_step = 0
946
+
947
+ proxy_lr_mult = 5.0
948
+ RAMP_EPOCHS = 3
949
+ WARMUP_EPOCHS = 1
950
+ VALIDATE_EVERY = 4 # validate every N epochs
951
+
952
+ from tqdm.auto import tqdm as tqdm_local
953
+
954
+ # Stage loop
955
+ for si, stage in enumerate(cfg.stages, 1):
956
+ if resume_info and si < resume_info["start_stage"]:
957
+ if is_main:
958
+ wlog_global(f"[resume] Skipping stage {si}; already completed.", also_print=True)
959
+ continue
960
+
961
+ # datasets per stage
962
+ T_w_tr, T_f_tr, T_e_tr = make_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
963
+ T_w_val, T_f_val, T_e_val = make_val_transforms(stage["sz_whole"], stage["sz_face"], stage["sz_eyes"])
964
+
965
+ train_ds = TriViewDataset(cfg.data_root, cfg.folders, split="train",
966
+ T_whole=T_w_tr, T_face=T_f_tr, T_eyes=T_e_tr)
967
+ val_ds = TriViewDataset(cfg.data_root, cfg.folders, split="val",
968
+ T_whole=T_w_val, T_face=T_f_val, T_eyes=T_e_val)
969
+
970
+ # steps_per_epoch schedule (based on the global batch)
971
+ steps_list = []
972
+ for ep_tmp in range(1, stage["epochs"]+1):
973
+ P_tmp, K_tmp = resolve_epoch_PK(stage, ep_tmp)
974
+ global_batch = P_tmp * K_tmp * world_size
975
+ steps = estimate_steps_per_epoch(
976
+ len(train_ds),
977
+ global_batch,
978
+ cfg.max_steps_per_epoch,
979
+ )
980
+ steps_list.append(steps)
981
+ total_steps_stage = int(sum(steps_list))
982
+ warmup_steps = int(steps_list[0] * WARMUP_EPOCHS)
983
+
984
+ params = [
985
+ {"params": model.parameters(), "lr": stage["lr"]},
986
+ {"params": proxy_loss.parameters(), "lr": stage["lr"] * proxy_lr_mult},
987
+ ]
988
+ optim = torch.optim.AdamW(params, weight_decay=cfg.weight_decay)
989
+ sched = make_cosine_with_warmup(optim, warmup_steps=warmup_steps, total_steps=total_steps_stage)
990
+
991
+ start_ep = 1
992
+ if resume_info and si == resume_info["start_stage"]:
993
+ start_ep = resume_info["start_epoch"]
994
+ if resume_info["last_stage"] == si and start_ep > 1:
995
+ try:
996
+ if resume_info["ckpt"].get("optim") is not None:
997
+ optim.load_state_dict(resume_info["ckpt"]["optim"])
998
+ if resume_info["ckpt"].get("sched") is not None:
999
+ sched.load_state_dict(resume_info["ckpt"]["sched"])
1000
+ if is_main:
1001
+ wlog_global(f"[resume] Loaded optimizer/scheduler from {resume_info['path']}.", also_print=True)
1002
+ except Exception as e:
1003
+ if is_main:
1004
+ wlog_global(f"[resume] WARNING: could not load optimizer/scheduler state: {e}", also_print=True)
1005
+
1006
+ stage_msg = (f"\n=== [DDP] Stage {si}/{len(cfg.stages)} :: "
1007
+ f"wh/face/eyes={stage['sz_whole']}/{stage['sz_face']}/{stage['sz_eyes']} | "
1008
+ f"epochs={stage['epochs']} | lr={stage['lr']} | classes={num_classes_total} ===")
1009
+ if is_main:
1010
+ print(stage_msg)
1011
+ wlog_global(stage_msg)
1012
+
1013
+ # epoch loop
1014
+ for ep in range(start_ep, stage["epochs"]+1):
1015
+ P_e, K_e = resolve_epoch_PK(stage, ep)
1016
+ B_e = P_e * K_e # local batch
1017
+ train_dl = build_train_loader(train_ds, P_e, K_e)
1018
+
1019
+ steps_per_epoch = steps_list[ep-1]
1020
+ model.train()
1021
+ proxy_loss.train()
1022
+
1023
+ running = {"proxy":0.0, "supcon":0.0, "mv":0.0, "tot":0.0}
1024
+ ep_sum_tot = ep_sum_p = ep_sum_s = ep_sum_m = 0.0
1025
+ ramp = min(1.0, ep / RAMP_EPOCHS)
1026
+
1027
+ if is_main:
1028
+ tbar = tqdm_local(range(1, steps_per_epoch+1),
1029
+ desc=f"[train-DDP] stage{si} ep{ep} (P={P_e},K={K_e},B={B_e},rank={rank})",
1030
+ leave=True)
1031
+ else:
1032
+ tbar = range(1, steps_per_epoch+1)
1033
+
1034
+ train_iter = iter(train_dl)
1035
+
1036
+ for it in tbar:
1037
+ try:
1038
+ batch = next(train_iter)
1039
+ except Exception as e:
1040
+ if _should_fallback_workers(e) and cfg.workers > 0:
1041
+ if is_main:
1042
+ print("[mp] Worker pickling error detected. Rebuilding loaders with num_workers=0.")
1043
+ cfg.workers = 0
1044
+ train_dl = build_train_loader(train_ds, P_e, K_e)
1045
+ train_iter = iter(train_dl)
1046
+ batch = next(train_iter)
1047
+ else:
1048
+ raise
1049
+
1050
+ labels = batch["labels"].to(device, non_blocking=True)
1051
+ gids = batch["gids"].to(device, non_blocking=True)
1052
+ views = {
1053
+ k: (v.to(device, non_blocking=True).to(memory_format=torch.channels_last)
1054
+ if v is not None else None)
1055
+ for k,v in batch["views"].items()
1056
+ }
1057
+ masks = {k: v.to(device, non_blocking=True) for k,v in batch["masks"].items()}
1058
+
1059
+ with torch.amp.autocast('cuda', dtype=amp_dtype):
1060
+ z_fused, z_views_dict, W = model(views, masks)
1061
+
1062
+ Z_all, Y_all, G_all = [], [], []
1063
+ for vk in ("whole","face","eyes"):
1064
+ zk = z_views_dict.get(vk)
1065
+ if zk is None:
1066
+ continue
1067
+ mk = masks[vk]
1068
+ if mk.any():
1069
+ Z_all.append(zk[mk])
1070
+ Y_all.append(labels[mk])
1071
+ G_all.append(gids[mk])
1072
+ if len(Z_all) == 0:
1073
+ Z_all, Y_all, G_all = [z_fused], [labels], [gids]
1074
+ Z_all = torch.cat(Z_all, dim=0)
1075
+ Y_all = torch.cat(Y_all, dim=0)
1076
+ G_all = torch.cat(G_all, dim=0)
1077
+
1078
+ L_proxy = proxy_loss(z_fused, labels)
1079
+ L_sup = supcon(Z_all, Y_all)
1080
+ L_mv = mv_infonce(Z_all, G_all)
1081
+ L_total = L_proxy + (0.5 * ramp) * L_sup + (0.5 * ramp) * L_mv
1082
+
1083
+ optim.zero_grad(set_to_none=True)
1084
+ scaler.scale(L_total).backward()
1085
+ scaler.step(optim)
1086
+ scaler.update()
1087
+ sched.step()
1088
+ global_step += 1
1089
+
1090
+ running["proxy"] += L_proxy.item()
1091
+ running["supcon"] += L_sup.item()
1092
+ running["mv"] += L_mv.item()
1093
+ running["tot"] += L_total.item()
1094
+
1095
+ ep_sum_tot += L_total.item()
1096
+ ep_sum_p += L_proxy.item()
1097
+ ep_sum_s += L_sup.item()
1098
+ ep_sum_m += L_mv.item()
1099
+
1100
+ if is_main and (it % cfg.print_every == 0 or it == steps_per_epoch):
1101
+ denom = min(cfg.print_every, it % cfg.print_every or cfg.print_every)
1102
+ tbar.set_postfix({
1103
+ "L": f"{running['tot']/denom:.3f}",
1104
+ "proxy": f"{running['proxy']/denom:.3f}",
1105
+ "sup": f"{running['supcon']/denom:.3f}",
1106
+ "mv": f"{running['mv']/denom:.3f}",
1107
+ "lr": f"{optim.param_groups[0]['lr']:.2e}",
1108
+ })
1109
+ msg = (f"stage{si} ep{ep:02d} it{it:05d}/{steps_per_epoch} | "
1110
+ f"P={P_e} K={K_e} B={B_e} | "
1111
+ f"L={running['tot']/denom:.3f} "
1112
+ f"(proxy={running['proxy']/denom:.3f}, "
1113
+ f"sup={running['supcon']/denom:.3f}, "
1114
+ f"mv={running['mv']/denom:.3f}) | "
1115
+ f"lr={optim.param_groups[0]['lr']:.2e}")
1116
+ wlog_global(msg)
1117
+ running = {k:0.0 for k in running}
1118
+
1119
+ # ===== Validation (proxy loss + proxy Top-1 only) =====
1120
+ proxy_top1 = float("nan")
1121
+ kmeans_acc = float("nan") # unused, kept only for the CSV column layout
1122
+ nmi = float("nan")
1123
+ ari = float("nan")
1124
+ knn_r1 = float("nan")
1125
+ knn_r5 = float("nan")
1126
+ val_proxy_mean = float("nan")
1127
+
1128
+ do_val = (VALIDATE_EVERY <= 0) or (ep % VALIDATE_EVERY == 0) or (ep == stage["epochs"])
1129
+
1130
+ if do_val:
1131
+ from torch.utils.data.distributed import DistributedSampler
1132
+
1133
+ val_sampler = DistributedSampler(
1134
+ val_ds,
1135
+ num_replicas=world_size,
1136
+ rank=rank,
1137
+ shuffle=False,
1138
+ drop_last=False,
1139
+ )
1140
+ val_sampler.set_epoch(ep)
1141
+
1142
+ val_dl_ddp = DataLoader(
1143
+ val_ds,
1144
+ batch_size=B_e,
1145
+ sampler=val_sampler,
1146
+ num_workers=min(8, cfg.workers),
1147
+ pin_memory=True,
1148
+ collate_fn=collate_triview,
1149
+ persistent_workers=False,
1150
+ multiprocessing_context=MP_CTX,
1151
+ )
1152
+
1153
+ model.eval()
1154
+ proxy_loss.eval()
1155
+
1156
+ local_loss_sum = 0.0
1157
+ local_loss_cnt = 0.0
1158
+ local_correct = 0.0
1159
+ local_total = 0.0
1160
+
1161
+ with torch.no_grad():
1162
+ Pn = F.normalize(proxy_loss.proxies.detach(), dim=1).to(device)
1163
+
1164
+ with torch.no_grad(), torch.amp.autocast('cuda', dtype=amp_dtype):
1165
+ for batch in val_dl_ddp:
1166
+ labels = batch["labels"].to(device, non_blocking=True)
1167
+ views = {
1168
+ k: (v.to(device).to(memory_format=torch.channels_last) if v is not None else None)
1169
+ for k, v in batch["views"].items()
1170
+ }
1171
+ masks = {k: v.to(device, non_blocking=True) for k, v in batch["masks"].items()}
1172
+
1173
+ z_fused, _, _ = model(views, masks)
1174
+ L = proxy_loss(z_fused, labels)
1175
+
1176
+ z_norm = F.normalize(z_fused, dim=1)
1177
+ logits = z_norm @ Pn.t()
1178
+ pred = logits.argmax(dim=1)
1179
+ correct = (pred == labels).float().sum().item()
1180
+
1181
+ bs = float(labels.size(0))
1182
+ local_loss_sum += L.item()
1183
+ local_loss_cnt += 1.0
1184
+ local_correct += correct
1185
+ local_total += bs
1186
+
1187
+ t = torch.tensor(
1188
+ [local_loss_sum, local_loss_cnt, local_correct, local_total],
1189
+ device=device,
1190
+ )
1191
+ dist.all_reduce(t, op=dist.ReduceOp.SUM)
1192
+ total_loss_sum = float(t[0].item())
1193
+ total_loss_cnt = max(1.0, float(t[1].item()))
1194
+ total_correct = float(t[2].item())
1195
+ total_total = max(1.0, float(t[3].item()))
1196
+
1197
+ val_proxy_mean = total_loss_sum / total_loss_cnt
1198
+ proxy_top1 = total_correct / total_total
1199
+
1200
+ if is_main:
1201
+ print(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}")
1202
+ wlog_global(f"[val] ep{ep:02d} proxy-loss ~ {val_proxy_mean:.3f}, Top1={proxy_top1:.4f}")
1203
+
1204
+ dist.barrier()
1205
+
1206
+ _close_dl(val_dl_ddp)
1207
+ del val_dl_ddp
1208
+ gc.collect()
1209
+ time.sleep(0.05)
1210
+
1211
+ # ----- Epoch metrics & checkpoint (rank0) -----
1212
+ train_mean = ep_sum_tot / steps_per_epoch
1213
+ train_p = ep_sum_p / steps_per_epoch
1214
+ train_s = ep_sum_s / steps_per_epoch
1215
+ train_m = ep_sum_m / steps_per_epoch
1216
+
1217
+ write_epoch_metrics(
1218
+ si, ep, steps_per_epoch, P_e, K_e,
1219
+ train_mean, train_p, train_s, train_m,
1220
+ val_proxy_mean, proxy_top1,
1221
+ knn_r1, knn_r5,
1222
+ kmeans_acc, nmi, ari,
1223
+ )
1224
+
1225
+ ck = os.path.join(cfg.out_dir, f"stage{si}_epoch{ep}.pt")
1226
+ save_ckpt(
1227
+ ck, model, proxy_loss, optim, sched,
1228
+ meta=dict(
1229
+ stage=si, epoch=ep,
1230
+ P=P_e, K=K_e, steps=steps_per_epoch,
1231
+ val_every=VALIDATE_EVERY,
1232
+ proxy_top1=proxy_top1,
1233
+ knn_r1=knn_r1, knn_r5=knn_r5,
1234
+ ),
1235
+ is_main=is_main,
1236
+ )
1237
+ if is_main:
1238
+ print(f"Saved: {ck}")
1239
+ wlog_global(f"Saved: {ck}")
1240
+
1241
+ _close_dl(train_dl)
1242
+ del train_dl
1243
+ gc.collect()
1244
+ time.sleep(0.1)
1245
+
1246
+ dist.destroy_process_group()
1247
+ if is_main:
1248
+ print("\n[DDP] Training finished or paused. Checkpoints in:", cfg.out_dir)
1249
+ print("Logs:", LOG_TXT, " | CSV:", METRICS_CSV)
1250
+ print("Tip: Re-run this script to RESUME (DDP).")
1251
+
1252
+ # ------------------------------ entry point --------------------------------
1253
+ def run_ddp_training():
1254
+ if not torch.cuda.is_available():
1255
+ print("CUDA not available; DDP training requires GPU.")
1256
+ return
1257
+
1258
+ world_size = torch.cuda.device_count()
1259
+ print(f"[DDP] Launching training on {world_size} GPUs...")
1260
+ mp.spawn(
1261
+ ddp_train_worker,
1262
+ args=(world_size,),
1263
+ nprocs=world_size,
1264
+ join=True,
1265
+ )
1266
+
1267
+ if __name__ == "__main__":
1268
+ run_ddp_training()
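+ # Usage note (single node): running `python scripts/train_style_ddp.py`
+ # spawns one worker per visible GPU via mp.spawn, and each run resumes
+ # automatically from the newest stage*_epoch*.pt in cfg.out_dir.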