|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from torch import Tensor |
|
|
from typing import List, Tuple |
|
|
from itertools import chain |
|
|
|
|
|
|
|
|
|
|
|
def expand_t_like_x(t, x):
    """Reshape a per-sample time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim, ...], data point

    Returns:
        A view of ``t`` with shape [batch_dim, 1, ..., 1], i.e. the leading
        dimension of ``t`` followed by ``x.dim() - 1`` singleton dimensions.
    """
    # One singleton axis for every non-batch axis of x.
    trailing_ones = (1,) * (x.dim() - 1)
    return t.view(t.size(0), *trailing_ones)
|
|
|
|
|
|
|
|
def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a three-layer SiLU MLP.

    Args:
        hidden_size: input feature size.
        projector_dim: width of the two hidden layers.
        z_dim: output feature size.

    Returns:
        ``nn.Sequential`` of
        Linear(hidden_size -> projector_dim), SiLU,
        Linear(projector_dim -> projector_dim), SiLU,
        Linear(projector_dim -> z_dim).
    """
    layers = [
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
def modulate(x, shift, scale):
    """Apply a shift/scale modulation: ``x * (1 + scale) + shift``.

    ``scale`` is expressed as an offset from 1, so ``scale == 0`` leaves
    ``x`` unscaled. Operands combine under ordinary broadcasting rules; no
    reshaping of ``shift`` or ``scale`` is done here.
    """
    gain = 1 + scale
    return x * gain + shift
|
|
|
|
|
|
|
|
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first tensor found on ``parameter``.

    Lookup order:

    1. the first parameter yielded by ``parameter.parameters()``;
    2. otherwise the first buffer yielded by ``parameter.buffers()``;
    3. otherwise the first plain tensor attribute stored in the ``__dict__``
       of ``parameter`` or any of its submodules (in ``modules()`` order).
       NOTE(review): presumably this covers module replicas that hold
       tensors as plain attributes instead of registered parameters
       (e.g. ``nn.DataParallel`` replicas in old PyTorch) -- confirm.

    Args:
        parameter: the module to inspect.

    Returns:
        The ``torch.dtype`` of the first tensor found, or ``None`` if neither
        the module nor its submodules hold any tensor.
    """
    # Only the first element is needed, so iterate lazily instead of
    # materializing every parameter and buffer into tuples.
    for tensor in chain(parameter.parameters(), parameter.buffers()):
        return tensor.dtype

    # This fallback used to sit under ``except StopIteration``, which could
    # never fire because ``tuple(...)`` absorbs iterator exhaustion. It is
    # now reached whenever there are no parameters and no buffers.
    for module in parameter.modules():
        for value in vars(module).values():
            if torch.is_tensor(value):
                return value.dtype

    # No tensors anywhere: keep the historical ``None`` result.
    return None