Spaces:
Runtime error
Runtime error
| """Miscellaneous functions. | |
| """ | |
| import torch | |
def log_sum_exp(x, axis=None):
    """Compute ``log(sum(exp(x)))`` in a numerically stable way.

    The maximum along the reduced axis is subtracted before exponentiating
    and added back afterwards, so large inputs do not overflow.

    Args:
        x: Input tensor.
        axis: Axis over which to perform the sum. If None, the sum runs
            over all elements of ``x``.

    Returns:
        torch.Tensor: log sum exp, with ``axis`` removed (0-dim when
            ``axis`` is None).
    """
    if axis is None:
        # Reduce over every element by flattening to a single axis.
        x = x.reshape(-1)
        axis = 0

    # keepdim=True so the shift broadcasts against ``x`` along the reduced
    # axis, whichever axis that is.
    x_max = torch.max(x, axis, keepdim=True)[0]

    # A non-finite maximum would produce ``inf - inf = nan``; shifting by
    # zero in that case lets exp/log propagate -inf / +inf naturally.
    shift = torch.where(torch.isfinite(x_max), x_max, torch.zeros_like(x_max))

    y = torch.log(torch.exp(x - shift).sum(axis)) + shift.squeeze(axis)
    return y
def random_permute(X):
    """Randomly permute a tensor along its first axis, per index of axis 2.

    For every index ``d`` along dimension 2, an independent random
    permutation of the first (batch) dimension is drawn and applied to the
    slice ``X[:, :, d]``; entries along dimension 1 move together.

    Args:
        X: Input tensor with at least 3 dimensions.

    Returns:
        torch.Tensor: Tensor of the same shape and device as ``X``.
    """
    # Bring the independently-permuted axis next to the batch axis.
    X = X.transpose(1, 2)
    n_batch, n_cols = X.size(0), X.size(1)

    # Sorting random keys down each column yields one random permutation of
    # the batch indices per column. Keys live on X's device so indexing
    # never mixes devices and no GPU is required.
    keys = torch.rand((n_batch, n_cols), device=X.device)
    idx = keys.sort(0)[1]

    # Column selector, broadcast against ``idx`` in the advanced index.
    adx = torch.arange(n_cols, device=X.device)

    X = X[idx, adx[None, :]].transpose(1, 2)
    return X