"""Helpers for model decryption, 1-D interpolation and small ``nn`` modules."""

import os
from io import BytesIO

import torch
from torch import nn

try:
    from Crypto.Cipher import AES
    from Crypto.Util.Padding import unpad
except ImportError as _exc:
    # pycryptodome is only needed by ``decrypt_model``; keep the tensor helpers
    # importable without it and report the problem when decryption is requested.
    AES = None
    unpad = None
    _CRYPTO_IMPORT_ERROR = _exc
else:
    _CRYPTO_IMPORT_ERROR = None


def decrypt_model(configs, input_path):
    """Decrypt an AES-CBC encrypted file and return the plaintext bytes.

    The first 16 bytes of the file at ``input_path`` are used as the IV and
    the remainder as the ciphertext. The key is read from ``decrypt.bin``
    inside ``configs["binary_path"]``.

    Args:
        configs: Mapping that provides ``"binary_path"`` (directory holding
            ``decrypt.bin``).
        input_path: Path of the encrypted file.

    Returns:
        The decrypted, unpadded payload as ``bytes``.

    Raises:
        ImportError: If pycryptodome (``Crypto``) is not installed.
    """
    if AES is None or unpad is None:
        raise ImportError(
            "decrypt_model requires the 'Crypto' package (pycryptodome)"
        ) from _CRYPTO_IMPORT_ERROR

    with open(input_path, "rb") as f:
        data = f.read()

    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()

    iv, ciphertext = data[:16], data[16:]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    # ``unpad`` already returns bytes, so no BytesIO round-trip is needed.
    return unpad(cipher.decrypt(ciphertext), AES.block_size)


def calc_same_padding(kernel_size):
    """Return a ``(left, right)`` padding pair derived from ``kernel_size``.

    Odd kernel sizes give a symmetric pair ``(k // 2, k // 2)``; even kernel
    sizes give one less on the right: ``(k // 2, k // 2 - 1)``.
    """
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)


def torch_interp(x, xp, fp):
    """Piecewise-linear interpolation of ``(xp, fp)`` evaluated at ``x``.

    ``xp`` does not need to be sorted; it is sorted internally together with
    ``fp``. Queries below the smallest ``xp`` take the first ``fp`` value and
    queries above the largest ``xp`` take the last one.

    Args:
        x: Tensor of query coordinates.
        xp: 1-D tensor of sample coordinates (must be non-empty).
        fp: 1-D tensor of sample values, aligned with ``xp``.

    Returns:
        Tensor of interpolated values, one per element of ``x``.
    """
    order = xp.argsort()
    xp = xp[order]
    fp = fp[order]

    # Bracket every query between two neighbouring samples.
    right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left_idxs = (right_idxs - 1).clamp(min=0)

    x_left = xp[left_idxs]
    y_left = fp[left_idxs]

    # When a query equals the smallest sample (or xp has a single point) both
    # bracket indices collapse to 0 and the span is zero. Dividing by it would
    # give 0/0 = NaN, so substitute 1: the numerator is zero there anyway and
    # the result becomes ``y_left``.
    span = xp[right_idxs] - x_left
    safe_span = torch.where(span == 0, torch.ones_like(span), span)

    interp_vals = y_left + (x - x_left) * (fp[right_idxs] - y_left) / safe_span

    # Out-of-range queries are held at the boundary values.
    interp_vals[x < xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]
    return interp_vals


def batch_interp_with_replacement_detach(uv, f0):
    """Fill the masked entries of each row of ``f0`` by linear interpolation.

    For every row ``i``, positions where ``uv[i]`` is True are replaced with
    values interpolated (over the last-axis index) from the positions where
    ``uv[i]`` is False. The interpolated values are detached from the autograd
    graph. ``f0`` itself is not modified; a clone is returned.

    Rows without any False position offer nothing to interpolate from and are
    left unchanged (previously this raised ``IndexError``).

    Args:
        uv: Boolean mask tensor; presumably shaped like ``f0`` with 1-D rows
            (TODO confirm against callers).
        f0: Tensor of values to be filled.

    Returns:
        A new tensor with the masked positions replaced.
    """
    result = f0.clone()
    for i in range(uv.shape[0]):
        mask = uv[i]
        known = ~mask
        known_idx = torch.where(known)[-1]
        if known_idx.numel() == 0:
            # No reference points in this row: keep the original values.
            continue
        interp_vals = torch_interp(
            torch.where(mask)[-1], known_idx, f0[i][known]
        ).detach()
        result[i][mask] = interp_vals
    return result


class DotDict(dict):
    """``dict`` whose keys can also be read, set and deleted as attributes.

    Missing attributes evaluate to ``None`` (via ``dict.get``) instead of
    raising ``AttributeError``. Plain ``dict`` values are wrapped in
    ``DotDict`` on attribute access so lookups can be chained.
    """

    def __getattr__(self, name):
        val = dict.get(self, name)
        # Exact type check on purpose: only plain dicts are wrapped, dict
        # subclasses are returned untouched.
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return x * x.sigmoid()


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for a fixed pair of dims."""

    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)


class GLU(nn.Module):
    """Gated linear unit: split the input in two along ``dim`` and gate.

    The first half is multiplied by the sigmoid of the second half.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()