Mesh R-CNN
Learn about the Mesh R-CNN architecture and how to use it to predict 3D models from real-world images.
Overview
Mesh R-CNN is a landmark model in 3D deep learning. It is one of the first computer vision models for 3D shape prediction that work on real-world images. Based on the modular R-CNN design, it reuses much of that architecture but introduces a new mesh prediction branch with several key innovations. We’ll begin by introducing Mesh R-CNN, then delve into the details of the mesh prediction branch, and follow up with some code examples. We lack the time, data, and compute needed to reproduce the entire Mesh R-CNN project, but we’ll explore several code examples that implement its key components.
Introduction to Mesh R-CNN
First introduced in the 2019 paper “Mesh R-CNN”, the Mesh R-CNN architecture builds on prior R-CNN models such as Faster R-CNN and Mask R-CNN. As a result, pretrained Mask R-CNN models can be used as the backbone, incorporating strong priors into Mesh R-CNN’s predictions. Mesh R-CNN adds a mesh prediction branch, which processes image features through a series of stages to predict a 3D model of each detected object.
Mesh R-CNN has a number of features that make it both innovative for its time and usable today. Some of these features include:
Predicts arbitrary (untextured) 3D meshes from a single image
Works on real-world images
Can use pretrained backbones that have been trained on general computer vision tasks (such as ResNet)
Doesn’t require a template mesh
Review of Mask R-CNN
Since Mesh R-CNN relies heavily on Mask R-CNN, we first briefly review the Mask R-CNN architecture. Mask R-CNN is built upon the Faster R-CNN architecture for object detection, which consists of two sequential stages:
A region proposal network (RPN) uses a convolutional neural network to propose candidate bounding boxes.
The RoIPool stage aggregates features from the bounding boxes for classification and regression.
Mask R-CNN’s key contribution is RoIAlign, a technique that enables segmentation mask prediction. It is a variation of RoIPool, the operation Faster R-CNN uses to pool features from bounding boxes. RoI, or region of interest, refers to a local image region (here, a candidate bounding box) from which relevant features are gathered. Unlike RoIPool, which quantizes box coordinates to the feature-map grid before pooling, RoIAlign applies bilinear interpolation to the underlying feature map to interpolate features for each output point. This lets it gather local information without losing spatial precision, which is essential for segmentation tasks.
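As a quick, standalone illustration (not Mesh R-CNN’s own code), torchvision ships a roi_align operator; the following sketch pools a fixed-size 7x7 feature grid from a single box on a dummy feature map:

import torch
from torchvision.ops import roi_align

# A dummy backbone feature map: batch of 1, 256 channels, 32x32 spatial grid
features = torch.randn(1, 256, 32, 32)

# One region of interest, given as (x1, y1, x2, y2) in feature-map coordinates
boxes = [torch.tensor([[4.0, 4.0, 20.0, 28.0]])]

# RoIAlign samples the feature map with bilinear interpolation to produce a
# fixed-size feature grid for each box, without quantizing its coordinates
pooled = roi_align(features, boxes, output_size=(7, 7), sampling_ratio=2)
print(pooled.shape)  # torch.Size([1, 256, 7, 7])

Whatever the size of the input box, the output grid has the same fixed shape, which is what allows the downstream heads to operate on per-region features of a constant size.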
Next, we take a deeper look at the mesh prediction branch.
Mesh prediction branch
The mesh prediction branch seeks to map the input features from RoIAlign to a 3D triangle mesh that best approximates the shape of the object. From the Mask R-CNN backbone we’re given a collection of image features from the 2D image space.
In the architecture, information from the RoIAlign operation flows into two separate branches: the box/mask branch and the voxel branch. The box/mask branch is identical to that found in Mask R-CNN. Mesh R-CNN introduces the voxel branch to predict meshes: it predicts a voxel grid, which is then converted into a mesh and refined by a mesh refinement stage for additional detail. We’ll now introduce each of these steps in closer detail.
Voxel prediction branch
The voxel prediction branch is essentially a generalization of the mask prediction branch. Instead of predicting a 2D mask of occupancy estimates, it predicts a 3D voxel grid of occupancy estimates. A small fully convolutional network (FCN) applied to the RoIAlign features produces these occupancy estimates.
The following code shows an example implementation of a voxel prediction branch, based on the official Mesh R-CNN implementation.
import torch
from torch import nn
from torch.nn import functional as F


class VoxelPredictionLayer(nn.Module):
    """A voxel head with several conv layers, plus an upsample layer (with `ConvTranspose2d`)"""

    def __init__(self, input_channels: int):
        """
        Args:
            input_channels: The number of channels in input tensor
        """
        super(VoxelPredictionLayer, self).__init__()

        # Number of Mesh R-CNN classes
        self.num_classes = 9
        # Dimensions of conv layers
        conv_dims = 256
        # The number of convs in the voxel head and the number of channels
        num_conv = 4
        # The number of depth channels for the predicted voxels
        self.num_depth = 28

        # Define conv layers
        self.conv_norm_relus = []
        for k in range(num_conv):
            conv = nn.Conv2d(
                input_channels if k == 0 else conv_dims,
                conv_dims,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True,
            )
            self.add_module("voxel_fcn{}".format(k + 1), conv)
            self.conv_norm_relus.append(conv)

        # Define deconv layer
        self.deconv = nn.ConvTranspose2d(
            conv_dims if num_conv > 0 else input_channels,
            conv_dims,
            kernel_size=2,
            stride=2,
            padding=0,
        )

        # Define output (prediction) layer
        self.predictor = nn.Conv2d(
            conv_dims,
            self.num_classes * self.num_depth,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        # Use normal distribution initialization for voxel prediction layer
        nn.init.normal_(self.predictor.weight, std=0.001)
        if self.predictor.bias is not None:
            nn.init.constant_(self.predictor.bias, 0)

    def forward(self, x):
        # Forward pass through conv layers
        for layer in self.conv_norm_relus:
            x = F.relu(layer(x))
        # Apply deconv layer
        x = F.relu(self.deconv(x))
        # Apply output layer
        x = self.predictor(x)
        # Reshape from (N, CD, H, W) to (N, C, D, H, W)
        x = x.reshape(x.size(0), self.num_classes, self.num_depth, x.size(2), x.size(3))
        return x


# Create an instance of a voxel prediction layer
voxel_head = VoxelPredictionLayer(4)
print(voxel_head)

# Test the forward pass
x = torch.randn([1, 4, 8, 8])
output = voxel_head(x)
print(output.shape)
The main objective of this layer is to convert the representation from a 2D grid into a 3D grid. This is implemented in a straightforward fashion with a single reshape call. Otherwise, this layer is a simple FCN consisting of Conv2d and ConvTranspose2d layers. The ConvTranspose2d layer increases the output resolution of the branch to provide more voxels to work with.
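To make the effect of these two operations concrete, here is a minimal sketch (with the same layer sizes as above, but a made-up input) showing how the ConvTranspose2d layer doubles the spatial resolution and how the final reshape splits the channel axis into class and depth axes:

import torch
from torch import nn

num_classes, num_depth = 9, 28

# A ConvTranspose2d with kernel_size=2 and stride=2 doubles the spatial resolution:
# an 8x8 feature map becomes 16x16
upsample = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2, padding=0)
x = torch.randn(1, 256, 8, 8)
print(upsample(x).shape)  # torch.Size([1, 256, 16, 16])

# The prediction layer emits num_classes * num_depth channels per pixel;
# reshaping splits that channel axis into separate class and depth axes,
# turning the 2D grid of predictions into a 3D occupancy grid per class
y = torch.randn(1, num_classes * num_depth, 16, 16)
y = y.reshape(y.size(0), num_classes, num_depth, y.size(2), y.size(3))
print(y.shape)  # torch.Size([1, 9, 28, 16, 16])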
Keep in mind the effect of the camera frustum. The region of 3D space that each pixel covers grows with distance from the camera. Since this frustum effect distorts the shape of the voxels, the voxel space is transformed by the camera intrinsic matrix, so that each predicted voxel corresponds to a small frustum in world space rather than an axis-aligned cube.
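The following small sketch illustrates why this matters, using a hypothetical intrinsic matrix K (the focal length and principal point here are made up): the same 1-unit step in world space projects to far fewer pixels when it lies farther from the camera, so a regular cubic voxel grid would not line up with the image features.

import torch

# A hypothetical intrinsic matrix K (focal length f, principal point at (cx, cy))
f, cx, cy = 100.0, 32.0, 32.0
K = torch.tensor([[f, 0.0, cx],
                  [0.0, f, cy],
                  [0.0, 0.0, 1.0]])

def project(point_3d):
    """Project a 3D point in camera coordinates to pixel coordinates using K."""
    p = K @ point_3d
    return p[:2] / p[2]  # perspective divide

# The same 1-unit offset in X spans fewer pixels the farther it is from the camera,
# which is exactly the frustum effect the voxel branch must account for
near = project(torch.tensor([1.0, 0.0, 2.0])) - project(torch.tensor([0.0, 0.0, 2.0]))
far = project(torch.tensor([1.0, 0.0, 10.0])) - project(torch.tensor([0.0, 0.0, 10.0]))
print(near, far)  # 50 pixels vs 10 pixels in x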
Voxel2mesh conversion
The voxel2mesh conversion step serves as a necessary bridge between the voxel prediction and mesh refinement branches. After all, to refine a mesh we first need a mesh. Given the grid of occupancy probabilities output by the voxel prediction branch, the voxel2mesh stage simply applies the familiar cubify operator. Recall that the cubify operator takes a grid of input voxels with occupancy values and a threshold: each voxel whose occupancy exceeds the threshold is replaced with a cuboid of 8 vertices and 12 triangular faces, shared vertices are merged, and interior faces between adjacent occupied voxels are removed, leaving a single triangle mesh of the object’s surface.
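As a brief sketch of this step (assuming PyTorch3D is installed), the cubify operator from pytorch3d.ops can be applied directly to a grid of occupancy probabilities; the values below are random stand-ins for the voxel branch’s output:

import torch
from pytorch3d.ops import cubify

# A random (1, D, H, W) grid of occupancy probabilities standing in for the
# voxel branch's output for a single detected object
voxels = torch.rand(1, 28, 28, 28)

# cubify keeps every voxel whose occupancy exceeds the threshold, replaces it
# with a cube of 8 vertices and 12 triangular faces, merges vertices shared by
# adjacent cubes, and removes interior faces between occupied neighbors
meshes = cubify(voxels, thresh=0.5)

print(meshes.verts_packed().shape)  # (V, 3) vertex positions
print(meshes.faces_packed().shape)  # (F, 3) triangle indices

The resulting Meshes object is exactly the kind of initial mesh that the refinement stage then deforms for additional detail.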