import os
import secrets
import string

import torch
import wandb
import yaml
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from torchdyn.core import NeuralODE

@torch.no_grad()
def gather_local_starts(x0s, X0_pool, N, k=64):
    """For each query point, sample N starts from its k nearest neighbours in X0_pool."""
    B, G = x0s.shape
    d2 = torch.cdist(x0s, X0_pool).pow(2)  # (B, |pool|) squared Euclidean distances
    knn_idx = d2.topk(k=min(k, X0_pool.size(0)), largest=False).indices  # (B, k) nearest indices
    x0_clusters = []
    for b in range(B):
        choices = knn_idx[b]
        # Randomly keep N of the k neighbours (without replacement; assumes N <= k).
        pick = choices[torch.randperm(choices.numel(), device=choices.device)[:N]]
        x0_clusters.append(X0_pool[pick])
    return torch.stack(x0_clusters, dim=0)  # (B, N, G)

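# Usage sketch (hypothetical shapes): draw 8 local starts per query point
# from a pool of 1000 candidates in a 2-D space.
#
#     X0_pool = torch.randn(1000, 2)
#     queries = torch.randn(4, 2)
#     clusters = gather_local_starts(queries, X0_pool, N=8, k=64)
#     clusters.shape  # torch.Size([4, 8, 2])
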
@torch.no_grad()
def make_aligned_clusters(ot_sampler, x0s, x1s, N, replace=True, k_local=128):
    """Build per-query source/target clusters of size N, aligned through the OT plan.

    Returns x0_clusters (B, N, G), x1_clusters (B, N, G), and idx1 (B, N),
    the indices of the matched targets in x1s.
    """
    device, dtype = x0s.device, x0s.dtype
    B, G = x0s.shape

    # Local starts are drawn from the batch itself, acting as its own pool.
    x0_clusters = gather_local_starts(x0s, x0s, N, k=k_local).to(device=device, dtype=dtype)
    x1_clusters = torch.empty((B, N, G), device=device, dtype=dtype)
    idx1 = torch.empty((B, N), device=device, dtype=torch.long)

    # Prefer an explicit transport plan if the sampler exposes one.
    P = None
    if hasattr(ot_sampler, "coupling"):
        P = ot_sampler.coupling(x0s, x1s)
    elif hasattr(ot_sampler, "plan"):
        P = ot_sampler.plan(x0s, x1s)

    for b in range(B):
        x0_b = x0s[b:b+1]

        if P is not None:
            # Sample N target indices from row b of the plan.
            probs = P[b].clamp_min(0)
            probs = probs / probs.sum().clamp_min(1e-12)
            if replace:
                j = torch.multinomial(probs, num_samples=N, replacement=True)
            else:
                k = min(N, int((probs > 0).sum().item()))
                j = torch.multinomial(probs, num_samples=k, replacement=False)
                if k < N:
                    # Support is smaller than N: pad by repeating the last drawn index.
                    j = torch.cat([j, j[-1:].expand(N - k)], dim=0)
            x1_match = x1s[j]
        else:
            # No explicit plan: fall back to the sampler's pair-sampling interface.
            got = False
            if hasattr(ot_sampler, "sample_plan"):
                try:
                    # Some samplers accept n_pairs and return N pairs at once.
                    _, x1_match = ot_sampler.sample_plan(
                        x0_b, x1s, replace=replace, n_pairs=N
                    )
                    x1_match = x1_match.view(N, G)
                    # Recover the matched indices by nearest neighbour in x1s.
                    j = torch.cdist(x1_match, x1s).argmin(dim=1)
                    got = True
                except TypeError:
                    pass
            if not got:
                # Sampler draws one pair per call: repeat N times.
                ys, js = [], []
                for _ in range(N):
                    _, x1_one = ot_sampler.sample_plan(x0_b, x1s, replace=replace)
                    # Map the sampled target back to its index in x1s.
                    j_hat = torch.cdist(x1_one.view(1, -1), x1s).argmin()
                    ys.append(x1_one.view(1, G))
                    js.append(j_hat.view(1))
                x1_match = torch.cat(ys, dim=0)
                j = torch.cat(js, dim=0)

        x1_clusters[b] = x1_match
        idx1[b] = j

    return x0_clusters, x1_clusters, idx1

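# Usage sketch; the OTPlanSampler name and API below are assumptions (e.g. a
# torchcfm-style sampler exposing sample_plan(x0, x1, replace=...)), not
# something this module prescribes:
#
#     sampler = OTPlanSampler(method="exact")           # hypothetical sampler
#     x0s, x1s = torch.randn(4, 2), torch.randn(500, 2)
#     x0_c, x1_c, idx1 = make_aligned_clusters(sampler, x0s, x1s, N=16)
#     # x0_c, x1_c: (4, 16, 2); idx1: (4, 16) indices into x1s
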
def load_config(path):
    with open(path, "r") as file:
        config = yaml.safe_load(file)
    return config

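# Usage sketch with a hypothetical config.yaml containing "lr: 0.001":
#
#     config = load_config("config.yaml")
#     config["lr"]  # 0.001
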
def merge_config(args, config_updates):
    """Overwrite attributes of args with config_updates, rejecting unknown keys."""
    for key, value in config_updates.items():
        if not hasattr(args, key):
            raise ValueError(
                f"Unknown configuration parameter '{key}' found in the config file."
            )
        setattr(args, key, value)
    return args

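# Usage sketch with hypothetical argparse defaults:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--lr", type=float, default=0.01)
#     args = parser.parse_args([])
#     args = merge_config(args, {"lr": 0.001})    # ok: args has an "lr" attribute
#     merge_config(args, {"unknown_key": 1})      # raises ValueError
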
def generate_group_string(length=16):
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))
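
# Usage sketch: a random identifier, e.g. for grouping wandb runs (the wandb
# usage is an assumption, not required by this helper):
#
#     group = generate_group_string()  # 16 random alphanumeric characters
#     # wandb.init(group=group)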