TiM / tim /models /utils /funcs.py
blanchon's picture
Update
3ed0796
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple
from itertools import chain
def expand_t_like_x(t, x):
    """Reshape the time vector ``t`` so it broadcasts against ``x``.

    The leading (batch) dimension of ``t`` is kept and one singleton
    dimension is appended for every remaining dimension of ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim,...], data point

    Returns:
        A view of ``t`` with shape ``[batch_dim, 1, ..., 1]`` whose number
        of dimensions equals that of ``x``.
    """
    # One singleton axis per non-batch dimension of x.
    trailing_ones = (1,) * (x.dim() - 1)
    return t.view(t.size(0), *trailing_ones)
def build_mlp(hidden_size, projector_dim, z_dim):
    """Build a three-layer MLP with SiLU activations between the layers.

    Layout: ``Linear(hidden_size, projector_dim) -> SiLU ->
    Linear(projector_dim, projector_dim) -> SiLU ->
    Linear(projector_dim, z_dim)``.

    Args:
        hidden_size: input feature size of the first linear layer.
        projector_dim: width of the two hidden linear layers.
        z_dim: output feature size of the last linear layer.

    Returns:
        An ``nn.Sequential`` holding the five modules in the order above.
    """
    widths = (hidden_size, projector_dim, projector_dim, z_dim)
    layers = []
    for index, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        # An activation precedes every linear layer except the first.
        if index > 0:
            layers.append(nn.SiLU())
        layers.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*layers)
def modulate(x, shift, scale):
    """Apply a scale-and-shift modulation to ``x``.

    Computes ``x * (1 + scale) + shift``, so a ``scale`` of zero leaves the
    magnitude of ``x`` unchanged and a ``shift`` of zero adds no offset.

    Args:
        x: value to modulate.
        shift: additive offset applied after scaling.
        scale: multiplicative modulation, expressed relative to 1.

    Returns:
        The modulated value.
    """
    gain = 1 + scale
    scaled = x * gain
    return scaled + shift
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first tensor found on a module.

    Lookup order:
      1. the first item yielded by ``parameter.parameters()``;
      2. the first item yielded by ``parameter.buffers()``;
      3. the first plain tensor attribute (a tensor stored directly in a
         module ``__dict__``) discovered through ``Module._named_members``.

    Args:
        parameter: the module to inspect.

    Returns:
        The ``torch.dtype`` of the first tensor found, or ``None`` when the
        module holds no tensors at all.
    """
    # Only the first element is needed, so avoid materialising every
    # parameter/buffer; next(..., None) never raises StopIteration.
    for tensors in (parameter.parameters(), parameter.buffers()):
        first = next(iter(tensors), None)
        if first is not None:
            return first.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5.
    # This fallback is reached by plain fall-through: placing it in an
    # `except StopIteration` handler would make it unreachable, because the
    # lookups above cannot raise that exception.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        # No parameters, buffers, or tensor attributes anywhere.
        return None
    return first_tuple[1].dtype