| from functools import partial | |
| import torch | |
| from torch import nn | |
class Swish(nn.Module):
    """Swish activation module: multiplies the input by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` for the input tensor ``x``."""
        gate = torch.sigmoid(x)
        return x * gate
def linear():
    """Build the identity ("linear") activation as a new ``nn.Identity``."""
    activation = nn.Identity()
    return activation
def relu():
    """Build a new ``nn.ReLU`` module with default settings."""
    activation = nn.ReLU()
    return activation
def prelu():
    """Build a new ``nn.PReLU`` module with default settings."""
    activation = nn.PReLU()
    return activation
def leaky_relu():
    """Build a new ``nn.LeakyReLU`` module with default settings."""
    activation = nn.LeakyReLU()
    return activation
def sigmoid():
    """Build a new ``nn.Sigmoid`` module."""
    activation = nn.Sigmoid()
    return activation
def softmax(dim=None):
    """Build a new ``nn.Softmax`` module.

    Args:
        dim: Forwarded unchanged as the ``dim`` argument of ``nn.Softmax``.
    """
    activation = nn.Softmax(dim=dim)
    return activation
def tanh():
    """Build a new ``nn.Tanh`` module."""
    activation = nn.Tanh()
    return activation
def gelu():
    """Build a new ``nn.GELU`` module with default settings."""
    activation = nn.GELU()
    return activation
def swish():
    """Build a new instance of this module's ``Swish`` activation."""
    activation = Swish()
    return activation
def register_activation(custom_act):
    """Register a custom activation, gettable with `activation.get`.

    The activation is stored in this module's global namespace under its
    ``__name__``.

    Args:
        custom_act: Custom activation function to register.

    Raises:
        ValueError: If the activation's name, or its lower-cased form, is
            already bound in this module's globals.
    """
    namespace = globals()
    act_name = custom_act.__name__
    # Both the exact name and its lower-cased form must be free.
    if act_name in namespace or act_name.lower() in namespace:
        raise ValueError(f"Activation {act_name} already exists. Choose another name.")
    namespace[act_name] = custom_act
def get(identifier):
    """Resolve an activation identifier to a callable.

    Args:
        identifier (str or Callable or None): the activation identifier.
            ``None`` and callables are returned unchanged; a string is
            looked up by name in this module's global namespace.

    Returns:
        Callable or None: ``None`` if `identifier` is None, `identifier`
        itself if it is callable, otherwise the callable bound to that name
        in this module. The object is returned as-is and is not called.

    Raises:
        ValueError: If `identifier` is a string that does not name a
            callable in this module, or if it is neither None, a callable
            nor a string.
    """
    if identifier is None:
        return None
    if callable(identifier):
        return identifier
    if isinstance(identifier, str):
        found = globals().get(identifier)
        # The module namespace also holds imported modules and dunder
        # metadata (e.g. ``torch``, ``__name__``). Those are not activations,
        # so anything that is not callable is rejected instead of returned.
        if callable(found):
            return found
    raise ValueError("Could not interpret activation identifier: " + str(identifier))