from typing import Optional

import torch
import torch.nn as nn

from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer
from mld.models.operator.position_encoding import build_position_encoding

class MldTrajEncoder(nn.Module):
    """Transformer encoder that compresses a per-frame trajectory feature
    sequence into ``latent_dim[0]`` global latent tokens of width
    ``latent_dim[-1]`` (or ``hidden_dim`` internally, if given)."""

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 force_post_proj: bool = False,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 position_embedding: str = "learned") -> None:
        super().__init__()
        # latent_dim is [num_latent_tokens, token_dim]; hidden_dim, if set,
        # overrides the transformer width.
        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim

        # Project back to the target token width when the transformer width
        # differs from it, or when a projection is explicitly requested.
        add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity()

        # Per-frame embedding; the input carries 3 values per feature
        # (e.g. xyz joint positions).
        self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
            norm_eps
        )
        encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)

        # Learnable query tokens that aggregate the whole sequence into
        # latent_size global tokens.
        self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim))

    def forward(self, features: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        bs, nframes, nfeats = features.shape
        # Treat every frame as valid when no padding mask is provided.
        if mask is None:
            mask = torch.ones((bs, nframes), dtype=torch.bool, device=features.device)
        # Embed frames and switch to (seq_len, batch, dim) layout.
        x = self.skel_embedding(features)
        x = x.permute(1, 0, 2)
        # Prepend one copy of the learnable global tokens per batch element
        # and mark those positions as valid in the padding mask.
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)
        xseq = torch.cat((dist, x), 0)
        xseq = self.query_pos_encoder(xseq)
        # Keep only the encoder outputs at the global-token positions.
        global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
        global_token = self.latent_proj(global_token)
        global_token = global_token.permute(1, 0, 2)
        return global_token
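

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original module: the feature
    # dimensionality (22 joints x 3 coords) and sequence length below are
    # illustrative assumptions, and running this requires the `mld` package
    # (SkipTransformerEncoder, build_position_encoding) to be importable.
    torch.manual_seed(0)
    bs, nframes, njoints = 2, 16, 22
    encoder = MldTrajEncoder(nfeats=njoints, latent_dim=[1, 256])
    features = torch.randn(bs, nframes, njoints * 3)
    mask = torch.ones(bs, nframes, dtype=torch.bool)
    mask[1, 10:] = False  # mark the tail of the second sequence as padding
    tokens = encoder(features, mask)
    print(tokens.shape)  # expected: torch.Size([2, 1, 256])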