Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
class ShapeAttrEmbedding(nn.Module):
    """Embed several categorical attributes into one feature vector.

    Every attribute column is one-hot encoded and passed through its own
    two-layer MLP (registered as ``attr_0``, ``attr_1``, ...). The
    per-attribute embeddings are concatenated along dim 1 and fused by a
    final two-layer MLP (``fusion``).

    Args:
        dim: Width of each per-attribute embedding.
        out_dim: Width of the fused output embedding.
        cls_num_list: Number of classes of each attribute, in the same
            order as the columns of the tensor given to ``forward``.
    """

    # Integer dtypes that ``F.one_hot`` rejects (it only accepts int64) but
    # that can be widened to int64 without changing any value.
    _NARROW_INT_DTYPES = (torch.uint8, torch.int8, torch.int16, torch.int32)

    def __init__(self, dim, out_dim, cls_num_list):
        super().__init__()
        # One small MLP per attribute. The ``attr_{idx}`` names are what
        # ``forward`` looks up and are part of the parameter keys, so they
        # are kept exactly as they were.
        for idx, cls_num in enumerate(cls_num_list):
            setattr(
                self,
                f'attr_{idx}',
                nn.Sequential(
                    nn.Linear(cls_num, dim),
                    nn.LeakyReLU(),
                    nn.Linear(dim, dim),
                ),
            )
        self.cls_num_list = cls_num_list
        self.attr_num = len(cls_num_list)
        # Fuses the concatenation of all per-attribute embeddings.
        self.fusion = nn.Sequential(
            nn.Linear(dim * self.attr_num, out_dim),
            nn.LeakyReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, attr):
        """Compute the fused attribute embedding.

        Args:
            attr: Integer tensor of class indices, indexed as
                ``attr[:, idx]``; expected layout is ``(batch, attr_num)``
                with column ``idx`` in ``[0, cls_num_list[idx])``.
                Any of uint8/int8/int16/int32/int64 is accepted.

        Returns:
            The output of ``self.fusion`` applied to the concatenated
            per-attribute embeddings (last dimension ``out_dim``).
        """
        attr_embedding_list = []
        for idx in range(self.attr_num):
            indices = attr[:, idx]
            # F.one_hot raises on anything but int64; widen narrow integer
            # dtypes so e.g. int32 attribute tensors work. Floating-point
            # input is intentionally NOT cast and still raises.
            if indices.dtype in self._NARROW_INT_DTYPES:
                indices = indices.long()
            one_hot = F.one_hot(indices, num_classes=self.cls_num_list[idx])
            attr_embed_fc = getattr(self, f'attr_{idx}')
            attr_embedding_list.append(
                attr_embed_fc(one_hot.to(torch.float32)))
        attr_embedding = torch.cat(attr_embedding_list, dim=1)
        return self.fusion(attr_embedding)