Spaces:
Sleeping
Sleeping
| # Models | |
| # from .conv_tasnet import ConvTasNet | |
| # from .dccrnet import DCCRNet | |
| # from .dcunet import DCUNet | |
| # from .dprnn_tasnet import DPRNNTasNet | |
| # from .sudormrf import SuDORMRFImprovedNet, SuDORMRFNet | |
| from .dptnet import DPTNet | |
| # from .lstm_tasnet import LSTMTasNet | |
| # from .demask import DeMask | |
| # Sharing-related | |
| # from .publisher import save_publishable, upload_publishable | |
# Public API of this package. Only DPTNet is currently exported; the other
# models (ConvTasNet, DPRNNTasNet, SuDORMRFImprovedNet, SuDORMRFNet,
# LSTMTasNet, DeMask, DCUNet, DCCRNet) and the sharing helpers
# (save_publishable, upload_publishable) are disabled, matching their
# commented-out imports above.
__all__ = ["DPTNet"]
def register_model(custom_model):
    """Register a custom model so that it can be retrieved with ``get``.

    The model is stored in this module's global namespace under
    ``custom_model.__name__``.

    Args:
        custom_model: Custom model to register (anything exposing a
            ``__name__``, typically a model class).

    Raises:
        ValueError: if a name equal to ``custom_model.__name__`` ignoring
            case is already defined in this module's namespace.
    """
    name = custom_model.__name__
    # ``get`` resolves names case-insensitively, so any existing global whose
    # name differs only by case would collide with (or shadow) the new model.
    # Comparing against every lowercased key covers both the exact-name and
    # all-lowercase cases as well as mixed-case variants.
    taken = {key.lower() for key in globals()}
    if name.lower() in taken:
        raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.")
    globals()[name] = custom_model
def get(identifier):
    """Return a model class from its name (case-insensitive).

    Args:
        identifier (str): the model name, e.g. ``"DPTNet"`` or ``"dptnet"``.

    Returns:
        The callable registered under ``identifier`` in this module's
        namespace, presumably a :class:`torch.nn.Module` subclass (not
        checked here).

    Raises:
        ValueError: if ``identifier`` is not a string, or no callable in this
            module's namespace matches it ignoring case.
    """
    if isinstance(identifier, str):
        # Only callables are candidates. The namespace also holds non-model
        # entries (``__all__``, dunder strings, and submodule objects that
        # Python binds in a package when one of its submodules is imported),
        # which could otherwise be returned directly or, sharing a lowercased
        # name with a model, shadow it depending on insertion order.
        to_get = {k.lower(): v for k, v in globals().items() if callable(v)}
        cls = to_get.get(identifier.lower())
        if cls is not None:
            return cls
    raise ValueError(f"Could not interpret model name : {str(identifier)}")