Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| from typing import List, Dict | |
| import torch as th | |
| import torch.nn as nn | |
| from pytorch3d.renderer import ( | |
| RasterizationSettings, | |
| MeshRasterizer, | |
| ) | |
| from pytorch3d.structures import Meshes | |
| from pytorch3d.renderer.mesh.textures import TexturesUV | |
| from pytorch3d.utils import cameras_from_opencv_projection | |
class RenderLayer(nn.Module):
    """Rasterize a batch of UV-textured meshes with PyTorch3D.

    Mesh topology (``vi``), UV coordinates (``vt``) and per-face UV indices
    (``vti``) are fixed at construction time and repeated across the batch in
    ``forward``; vertex positions, texture maps and cameras are per call.

    Args:
        h: Output image height in pixels.
        w: Output image width in pixels.
        vi: Face vertex indices; presumably ``(F, 3)`` — TODO confirm.
        vt: UV coordinates; presumably ``(V_uv, 2)`` — TODO confirm.
        vti: Per-face UV indices; presumably ``(F, 3)`` — TODO confirm.
        flip_uvs: If False (default), ``tex`` is permuted from channels-first
            to channels-last and flipped along its height axis before being
            handed to ``TexturesUV``. If True, ``tex`` is passed through
            unchanged, so the caller must already supply it in the layout
            ``TexturesUV`` expects. NOTE(review): confirm this asymmetry is
            intended.
    """

    def __init__(self, h, w, vi, vt, vti, flip_uvs=False):
        super().__init__()
        # Non-persistent: topology is rebuilt from constructor args and is
        # not saved in the state dict.
        self.register_buffer("vi", vi, persistent=False)
        self.register_buffer("vt", vt, persistent=False)
        self.register_buffer("vti", vti, persistent=False)
        raster_settings = RasterizationSettings(image_size=(h, w))
        self.rasterizer = MeshRasterizer(raster_settings=raster_settings)
        self.flip_uvs = flip_uvs
        # (h, w) as int32, the order cameras_from_opencv_projection expects.
        image_size = th.as_tensor([h, w], dtype=th.int32)
        self.register_buffer("image_size", image_size)

    def forward(
        self,
        verts: th.Tensor,
        tex: th.Tensor,
        K: th.Tensor,
        Rt: th.Tensor,
        background: th.Tensor = None,
        output_filters: List[str] = None,
    ) -> Dict[str, th.Tensor]:
        """Render textured meshes from OpenCV-convention cameras.

        Args:
            verts: Batched vertex positions; batch size is ``verts.shape[0]``
                and every output lives on ``verts.device``.
            tex: Texture maps. When ``flip_uvs`` is False this is assumed to
                be channels-first ``(B, C, H, W)`` — TODO confirm.
            K: OpenCV intrinsics, presumably ``(B, 3, 3)``.
            Rt: OpenCV extrinsics ``(B, 3, 4)``; a ``(B, 4, 4)`` matrix is
                also accepted (only its top three rows are used).
            background: Unsupported; must be None.
            output_filters: Unsupported; must be None.

        Returns:
            ``{"render": rgb}`` where ``rgb`` is channels-first
            ``(B, C, h, w)``, with pixels not covered by any face set to 0.

        Raises:
            NotImplementedError: If ``background`` or ``output_filters`` is
                given.
        """
        # Explicit raises instead of asserts: asserts vanish under ``-O``,
        # which would silently drop a requested background.
        if output_filters is not None:
            raise NotImplementedError("output_filters is not supported")
        if background is not None:
            raise NotImplementedError("background is not supported")

        device = verts.device
        B = verts.shape[0]

        # Put every per-call input on the device of ``verts`` so cameras,
        # mesh and textures agree (``.to`` is a no-op if already there).
        K = K.to(device)
        Rt = Rt.to(device)
        tex = tex.to(device)

        image_size = self.image_size[None].repeat(B, 1).to(device)

        # Top-left 3x3 as rotation, first three rows of column 3 as
        # translation; identical to Rt[:, :, :3] for (B, 3, 4) input.
        cameras = cameras_from_opencv_projection(
            Rt[:, :3, :3], Rt[:, :3, 3], K, image_size
        )

        faces = self.vi[None].repeat(B, 1, 1).to(device)
        faces_uvs = self.vti[None].repeat(B, 1, 1).to(device)
        verts_uvs = self.vt[None].repeat(B, 1, 1).to(device)

        # permute/flip return new tensors; the caller's ``tex`` is untouched.
        if not self.flip_uvs:
            tex = tex.permute(0, 2, 3, 1).flip((1,))

        textures = TexturesUV(
            maps=tex,
            faces_uvs=faces_uvs,
            verts_uvs=verts_uvs,
        )
        meshes = Meshes(verts, faces, textures=textures)

        fragments = self.rasterizer(meshes, cameras=cameras)
        # Keep only the first per-pixel face sample (index 0 on axis 3).
        rgb = meshes.sample_textures(fragments)[:, :, :, 0]
        # pix_to_face == -1 marks pixels hit by no face; zero them out.
        rgb[fragments.pix_to_face[..., 0] == -1] = 0.0

        return {"render": rgb.permute(0, 3, 1, 2)}