...

Masked Autoencoders: Decoder and Loss Function

Learn how to implement the decoder layer and loss function of Masked Autoencoders (MAE).


Decoder

The input to the MAE decoder consists of all the tokens—that is:

  • Encoded visible patches, and

  • Mask tokens.

Similar to SimMIM, a shared, learnable mask token vector is used as a substitute for the missing (masked) patches in the input. The full set of tokens is then passed through a transformer network containing self-attention layers.
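To make this concrete, here is a quick token count using the MaskGenerator defaults from the code later in this lesson (a 192×192 input, 16×16 mask patches, and an 80% mask ratio); the numbers follow directly from those settings:

import math

input_size, mask_patch_size, mask_ratio = 192, 16, 0.8

token_count = (input_size // mask_patch_size) ** 2      # 144 patch positions in total
mask_count = int(math.ceil(token_count * mask_ratio))   # 116 positions are masked
visible_count = token_count - mask_count                # 28 positions stay visible

print(visible_count + 1)  # encoder input length: visible patches + [CLS] = 29 tokens
print(token_count + 1)    # decoder input length: all patches + [CLS] = 145 tokens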

The goal of the MAE decoder is to perform the image reconstruction task. Note that the decoder is only used during pre-training (i.e., only the encoder is carried over to the transfer learning step). The decoder design is flexible: you can opt for a shallow decoder to keep the pre-training overhead to a minimum.
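For a sense of scale, the decoder in the original MAE paper is deliberately lightweight: its default is 8 transformer blocks of width 512 (paired with a much wider ViT-L encoder), and the authors report that even a very shallow decoder still fine-tunes well. A hypothetical configuration for such a lightweight decoder could look like the following (the dictionary and its keys are illustrative, not part of this lesson's code):

decoder_cfg = dict(
    embed_dim=512,   # decoder width; narrower than the encoder width
    depth=8,         # number of transformer blocks (the MAE paper's default)
    num_heads=16,    # attention heads per decoder block
)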

The code below implements the decoder layer of MAEs; the MaskGenerator and EncoderForMAE classes shown first are the supporting pieces it builds on.

import torch
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as T
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer
import warnings
warnings.simplefilter("ignore", UserWarning)
class MaskGenerator:
    """Randomly selects which patch positions of each image are masked."""
    def __init__(self, input_size=192, mask_patch_size=16, mask_ratio=0.8):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.mask_ratio = mask_ratio
        assert self.input_size % self.mask_patch_size == 0
        self.rand_size = self.input_size // self.mask_patch_size
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self, images):
        mask_ids, ids_keeps, masks = [], [], []
        for image in images:
            # Shuffle all patch positions; the first mask_count are masked.
            permutation = np.random.permutation(self.token_count)
            mask_idx, ids_keep = permutation[:self.mask_count], permutation[self.mask_count:]
            mask = np.ones(self.token_count, dtype=int)
            mask[mask_idx] = 0
            mask = mask.reshape((self.rand_size, self.rand_size))
            # Upsample the patch-level mask to pixel resolution.
            mask = mask.repeat(self.mask_patch_size, axis=0).repeat(self.mask_patch_size, axis=1)
            mask_ids.append(mask_idx)
            ids_keeps.append(ids_keep)
            masks.append(mask)
        return mask_ids, ids_keeps, masks


class EncoderForMAE(VisionTransformer):
    """ViT encoder that processes only the visible (unmasked) patches."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert self.num_classes == 0

    def forward(self, x, mask_ids, ids_keep):
        x = self.patch_embed(x)
        B, L, D = x.shape
        x = x + self.pos_embed[:, 1:, :]
        # Keep only the visible patch tokens (ids_keep is a (B, P_visible) index tensor).
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
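The numbered notes below walk through a DecoderForMAE class from the code widget that is not reproduced above. As a reference point, here is a minimal sketch of what such a decoder could look like, assuming a timm VisionTransformer base and the attribute names used in the notes (self.decoder_embed, self.mask_token, self.decoder_pred); it is a hedged approximation, not the widget's exact implementation:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer

class DecoderForMAE(VisionTransformer):
    def __init__(self, encoder_embed_size=768, patch_size=16, **kwargs):
        # num_classes=0 because the decoder has no classification head; embed_dim,
        # depth, and num_heads (passed via kwargs) define a lightweight decoder.
        super().__init__(patch_size=patch_size, num_classes=0, **kwargs)
        # Projects encoder features (width d) into the decoder width d' = self.embed_dim.
        self.decoder_embed = nn.Linear(encoder_embed_size, self.embed_dim)
        # One shared, learnable token substituted for every masked patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, std=0.02)
        # Prediction head: one flattened pixel patch (patch_size^2 * 3 values) per token.
        self.decoder_pred = nn.Linear(self.embed_dim, patch_size ** 2 * 3)

    def forward(self, x, mask_ids, ids_keep):
        # x: (B, 1 + P_visible, d) encoder output with the [CLS] token first.
        # mask_ids / ids_keep: (B, P_masked) and (B, P_visible) LongTensors of patch positions.
        x = self.decoder_embed(x)
        B, _, D = x.shape
        num_patches = ids_keep.shape[1] + mask_ids.shape[1]

        # Fill every patch position with the shared mask token, then scatter the
        # projected visible tokens back to their original positions.
        tokens = self.mask_token.expand(B, num_patches, D).clone()
        tokens.scatter_(1, ids_keep.unsqueeze(-1).expand(-1, -1, D), x[:, 1:, :])

        # Re-attach [CLS], add the decoder's positional embeddings, and run the blocks.
        x = torch.cat((x[:, :1, :], tokens), dim=1) + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # Reconstructed (flattened) pixel patches for every position; the [CLS] output is dropped.
        return self.decoder_pred(x)[:, 1:, :]

# Example instantiation (values are illustrative; img_size and patch_size must match
# the MaskGenerator settings so the positional embeddings cover all 144 patches):
# decoder = DecoderForMAE(encoder_embed_size=768, img_size=192, patch_size=16,
#                         embed_dim=512, depth=8, num_heads=16)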
  1. Lines 13–18: We define a class, DecoderForMAE, which inherits from the VisionTransformer class. It takes the encoder embedding size, encoder_embed_size, as input and initializes a projection layer, self.decoder_embed (which projects encoder features of size encoder_embed_size to the decoder input space of size self.embed_dim), the learnable mask token, self.mask_token, and the decoder prediction head, self.decoder_pred.

  2. Lines 21–38: We define the forward() function, which takes encoder features, x, masked indexes, mask_ids, and visible patch indexes, ids_keep, as inputs, and returns the reconstructed image patches.

  3. Line 22: We project the encoder features, x, to the decoder input space using a linear projection layer, self.decoder_embed. This layer projects the encoder features of shape $B \times (P_{\text{visible}} + 1) \times d$ to the decoder input space of shape $B \times (P_{\text{visible}} + 1) \times d'$. Here, ...