import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor

from random import random

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

# helpers

def exists(val):
    return val is not None

def identity(t):
    return t

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def coin_flip():
    return random() < 0.5

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]
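
# usage sketch (illustrative only, not part of the original file): pack_one / unpack_one
# wrap einops pack / unpack for the single-tensor case, e.g. collapsing leading
# batch-like dims before an op on the last axis and restoring them afterwards
#   t = torch.randn(2, 3, 4)
#   packed, ps = pack_one(t, '* d')            # packed.shape == (6, 4)
#   restored = unpack_one(packed, ps, '* d')   # restored.shape == (2, 3, 4)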
# tensor helpers

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
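
# usage sketch (illustrative only, not part of the original file): draws an
# independent Bernoulli(prob) sample per position, useful e.g. for randomly
# dropping out conditioning
#   keep_mask = prob_mask_like((2, 8), 0.5, device = torch.device('cpu'))
#   keep_mask.shape, keep_mask.dtype   # -> torch.Size([2, 8]), torch.bool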
def reduce_masks_with_and(*masks):
    masks = [*filter(exists, masks)]

    if len(masks) == 0:
        return None

    mask, *rest_masks = masks

    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask
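
# usage sketch (illustrative only, not part of the original file): None entries
# are filtered out, so optional masks can be combined without special-casing
#   a = torch.tensor([True, True, False])
#   b = torch.tensor([True, False, True])
#   reduce_masks_with_and(a, None, b)   # -> tensor([ True, False, False])
#   reduce_masks_with_and(None, None)   # -> None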
def interpolate_1d(t, length, mode = 'bilinear'):
    " pytorch does not offer 1d 'bilinear' interpolation, so hack it by adding a dummy spatial dimension and interpolating in 2d "
    dtype = t.dtype
    t = t.float()

    implicit_one_channel = t.ndim == 2

    if implicit_one_channel:
        t = rearrange(t, 'b n -> b 1 n')

    t = rearrange(t, 'b d n -> b d n 1')
    t = F.interpolate(t, (length, 1), mode = mode)
    t = rearrange(t, 'b d n 1 -> b d n')

    if implicit_one_channel:
        t = rearrange(t, 'b 1 n -> b n')

    t = t.to(dtype)
    return t
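
# usage sketch (illustrative only, not part of the original file): resamples
# along the sequence axis; 2d input (b, n) is treated as single-channel,
# 3d input is (b, d, n)
#   x = torch.randn(2, 10)
#   interpolate_1d(x, 20).shape   # -> torch.Size([2, 20])
#   y = torch.randn(2, 4, 10)
#   interpolate_1d(y, 5).shape    # -> torch.Size([2, 4, 5])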
def curtail_or_pad(t, target_length):
    length = t.shape[-2]

    if length > target_length:
        t = t[..., :target_length, :]
    elif length < target_length:
        t = F.pad(t, (0, 0, 0, target_length - length), value = 0.)

    return t
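
# usage sketch (illustrative only, not part of the original file): trims or
# zero-pads the second-to-last (sequence) dimension to exactly target_length
#   x = torch.randn(2, 10, 4)
#   curtail_or_pad(x, 6).shape    # -> torch.Size([2, 6, 4])
#   curtail_or_pad(x, 16).shape   # -> torch.Size([2, 16, 4])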
# mask construction helpers

def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    assert start.shape == end.shape
    device = start.device

    seq = torch.arange(seq_len, device = device, dtype = torch.long)
    seq = seq.reshape(*((1,) * start.ndim), seq_len)  # explicit 1s, as reshape permits only a single inferred (-1) dim
    seq = seq.expand(*start.shape, seq_len)

    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask
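
# usage sketch (illustrative only, not part of the original file): builds a
# boolean mask that is True on the half-open span [start, end) per batch element
#   start = torch.tensor([1, 0])
#   end = torch.tensor([3, 2])
#   mask_from_start_end_indices(5, start, end)
#   # -> tensor([[False,  True,  True, False, False],
#   #            [ True,  True, False, False, False]])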
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    device = frac_lengths.device

    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
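
# usage sketch (illustrative only, not part of the original file): masks a
# randomly placed contiguous span covering the given fraction of the sequence,
# per batch element
#   frac = torch.tensor([0.5, 0.25])
#   mask = mask_from_frac_lengths(8, frac)
#   mask.shape, mask.sum(dim = -1)   # -> torch.Size([2, 8]), tensor([4, 2])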
# sinusoidal positions