import os
from io import BytesIO

import torch
from torch import nn
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad


def decrypt_model(configs, input_path):
    """Decrypt an AES-CBC encrypted model file and return the plaintext bytes.

    The file at `input_path` is expected to start with a 16-byte IV, followed
    by the padded ciphertext; the raw AES key is read from decrypt.bin under
    configs["binary_path"].
    """
    with open(input_path, "rb") as f:
        data = f.read()
    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()
    iv, ciphertext = data[:16], data[16:]
    return unpad(AES.new(key, AES.MODE_CBC, iv).decrypt(ciphertext), AES.block_size)
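
# Usage sketch; the paths and the assumption that the decrypted payload is a
# torch checkpoint are illustrative, not part of this module:
#
#   raw = decrypt_model({"binary_path": "assets/keys"}, "assets/model.enc")
#   state = torch.load(BytesIO(raw), map_location="cpu")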


def calc_same_padding(kernel_size):
    """Return (left, right) padding so a stride-1 convolution preserves the
    input length, for both odd and even kernel sizes."""
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
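
# Quick self-check (a sketch; the Conv1d is just an arbitrary stride-1 conv
# used to verify the output length):
#
#   x = torch.randn(1, 8, 100)
#   for k in (3, 4, 5):
#       left, right = calc_same_padding(k)          # e.g. (2, 1) for k == 4
#       y = nn.Conv1d(8, 8, k)(nn.functional.pad(x, (left, right)))
#       assert y.shape[-1] == x.shape[-1]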


def torch_interp(x, xp, fp):
    """1-D linear interpolation of queries `x` over sample points (xp, fp),
    mirroring numpy.interp: out-of-range queries clamp to fp[0] / fp[-1]."""
    # Sort the sample points so searchsorted sees a monotonic sequence.
    sort_idx = xp.argsort()
    xp = xp[sort_idx]
    fp = fp[sort_idx]
    # Bracket each query between its left and right sample points.
    right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left_idxs = (right_idxs - 1).clamp(min=0)
    x_left = xp[left_idxs]
    y_left = fp[left_idxs]
    interp_vals = y_left + (x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left)
    # Clamp out-of-range queries. Using `<=` on the left edge also repairs the
    # 0/0 NaN produced when a query lands exactly on xp[0], where the left and
    # right bracket indices coincide.
    interp_vals[x <= xp[0]] = fp[0]
    interp_vals[x > xp[-1]] = fp[-1]
    return interp_vals
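
# Behaviour matches numpy.interp on this small sketch:
#
#   xp = torch.tensor([0.0, 1.0, 2.0])
#   fp = torch.tensor([0.0, 10.0, 20.0])
#   torch_interp(torch.tensor([-1.0, 0.5, 2.5]), xp, fp)
#   # -> tensor([ 0.,  5., 20.])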


def batch_interp_with_replacement_detach(uv, f0):
    """Per batch item, replace the f0 values at frames where `uv` is True with
    values linearly interpolated from the remaining frames. The interpolated
    values are detached, so no gradient flows through the replacements."""
    result = f0.clone()
    for i in range(uv.shape[0]):
        interp_vals = torch_interp(
            torch.where(uv[i])[-1],   # frame indices to fill
            torch.where(~uv[i])[-1],  # frame indices with known values
            f0[i][~uv[i]],            # the known values themselves
        ).detach()
        result[i][uv[i]] = interp_vals
    return result
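
# Sketch: rebuild masked F0 frames from their neighbours (treating zeros as
# the frames to replace is an assumption for this example):
#
#   f0 = torch.tensor([[100.0, 0.0, 0.0, 130.0]])
#   uv = f0 == 0.0
#   batch_interp_with_replacement_detach(uv, f0)
#   # -> tensor([[100., 110., 120., 130.]])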


class DotDict(dict):
    """Dict subclass with attribute-style access; nested dicts are wrapped in
    DotDict on read. Note: missing keys return None rather than raising
    AttributeError, and the wrapper for a nested dict is a shallow copy."""

    def __getattr__(*args):
        # Invoked as __getattr__(self, name), so dict.get(*args) == self.get(name).
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
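
# Usage sketch:
#
#   cfg = DotDict({"model": {"hidden": 256}})
#   cfg.model.hidden   # -> 256
#   cfg.lr = 1e-4      # routed through dict.__setitem__, so cfg["lr"] == 1e-4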


class Swish(nn.Module):
    """Swish activation: x * sigmoid(x), the same function as nn.SiLU."""

    def forward(self, x):
        return x * x.sigmoid()
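
# Equivalence check (sketch):
#
#   x = torch.randn(3, 4)
#   torch.allclose(Swish()(x), nn.SiLU()(x))   # -> True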


class Transpose(nn.Module):
    """Swap two dimensions of the input, e.g. inside an nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims must name exactly two dimensions"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)
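
# Typical use: hop between (batch, channels, time) and (batch, time, channels)
# around modules that normalise over the last dimension (a sketch):
#
#   block = nn.Sequential(Transpose((1, 2)), nn.LayerNorm(256), Transpose((1, 2)))
#   block(torch.randn(4, 256, 100)).shape   # -> torch.Size([4, 256, 100])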


class GLU(nn.Module):
    """Gated linear unit: split the input in two along `dim` and gate one
    half with the sigmoid of the other, halving that dimension (as nn.GLU)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()
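
# Shape sketch: the gated dimension is halved.
#
#   GLU(dim=1)(torch.randn(4, 512, 100)).shape   # -> torch.Size([4, 256, 100])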