| | import torch.nn as nn |
| | import torch |
| | import numpy as np |
| |
|
class PointEmbed(nn.Module):
    """Fixed sinusoidal embedding for points with three coordinates.

    Each of the three coordinates is multiplied by the frequencies
    ``2**k * pi`` for ``k = 0 .. hidden_dim // 6 - 1``; the sine and cosine
    of every product are concatenated, giving ``hidden_dim`` features per
    point. The module has no learnable parameters: the frequency matrix is
    stored as the buffer ``basis``.
    """

    def __init__(self, hidden_dim=48):
        """Build the frequency matrix.

        Args:
            hidden_dim: Size of the produced embedding. Must be divisible
                by 6 (3 coordinates x sin/cos pair).

        Raises:
            AssertionError: If ``hidden_dim`` is not a multiple of 6.
        """
        super().__init__()

        assert hidden_dim % 6 == 0, (
            f"hidden_dim must be divisible by 6, got {hidden_dim}"
        )

        self.embedding_dim = hidden_dim
        num_freqs = self.embedding_dim // 6

        # Frequencies pi, 2*pi, 4*pi, ... shared by all three coordinates.
        freqs = torch.pow(2, torch.arange(num_freqs)).float() * np.pi

        # Block layout: row `axis` holds the frequencies in its own column
        # slice and zeros elsewhere, so every output column depends on
        # exactly one input coordinate.
        basis = torch.zeros(3, 3 * num_freqs)
        for axis in range(3):
            basis[axis, axis * num_freqs:(axis + 1) * num_freqs] = freqs

        self.register_buffer('basis', basis)

    @staticmethod
    def embed(input, basis):
        """Project ``input`` onto ``basis`` and return sin/cos features.

        Args:
            input: Tensor of shape ``(..., D)`` where ``D`` equals
                ``basis.shape[0]``. Any number of leading dimensions
                (including none) is accepted.
            basis: Tensor of shape ``(D, E)``.

        Returns:
            Tensor of shape ``(..., 2 * E)``: the sines of the projections
            followed by their cosines along the last dimension.
        """
        # The ellipsis keeps all leading dims, so (B, N, D), (N, D) and (D,)
        # inputs are all handled by the same contraction.
        projections = torch.einsum('...d,de->...e', input, basis)
        return torch.cat([projections.sin(), projections.cos()], dim=-1)

    def forward(self, input):
        """Embed ``input`` of shape ``(..., 3)`` into ``(..., hidden_dim)``."""
        return self.embed(input, self.basis)
| |
|