| from typing import Set | |
import spconv

# Gate on an integer (major, minor) pair. Slicing the version string and
# parsing the remainder as a float drops the major version and misorders
# releases such as "2.2.0", "2.2.10" or "3.0.0".
try:
    _spconv_major_minor = tuple(
        int(part)
        for part in getattr(spconv, "__version__", "").split(".")[:2]
    )
except ValueError:
    # Unrecognised version layout: leave spconv's defaults untouched rather
    # than failing at import time.
    _spconv_major_minor = ()

if len(_spconv_major_minor) == 2 and _spconv_major_minor >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

try:
    # Prefer the ``spconv.pytorch`` namespace when the installed release
    # provides it.
    import spconv.pytorch as spconv
except ImportError:
    # Otherwise keep using the top-level package under the same name.
    import spconv as spconv
| import torch.nn as nn | |
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """Find all spconv keys that need to have their weights transposed.

    Walks ``model.named_children()`` recursively and records
    ``"<dotted path>.weight"`` for every child that is an instance of
    ``spconv.conv.SparseConvolution``.

    Args:
        model: Module whose descendants are searched. ``model`` itself is
            not type-checked, only its children (recursively).
        prefix: Dotted path of ``model`` inside an enclosing module. With the
            default ``""`` the returned keys are relative to ``model``.

    Returns:
        The set of dotted weight keys; empty when no sparse convolution is
        found.
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_path = f"{prefix}.{name}" if prefix != "" else name

        if isinstance(child, spconv.conv.SparseConvolution):
            found_keys.add(f"{child_path}.weight")

        # Recurse with the module path, never the ".weight" key, so that
        # anything nested below a sparse convolution keeps a valid path.
        found_keys.update(find_all_spconv_keys(child, prefix=child_path))

    return found_keys
def replace_feature(out, new_features):
    """Attach ``new_features`` to ``out`` using whichever API it supports.

    Args:
        out: Object carrying a ``features`` attribute and, optionally, a
            ``replace_feature`` method.
        new_features: Value to install as the object's features.

    Returns:
        The result of ``out.replace_feature(new_features)`` when that method
        exists (spconv 2.x behaviour); otherwise ``out`` itself after its
        ``features`` attribute has been overwritten in place.
    """
    # A direct attribute lookup avoids building the full ``__dir__()`` list
    # on every call and also sees attributes that ``__dir__`` does not report.
    replacer = getattr(out, "replace_feature", None)
    if replacer is not None:
        # spconv 2.x behaviour
        return replacer(new_features)

    out.features = new_features
    return out