Masked Autoencoders: Decoder and Loss Function
Learn how to implement the decoder layer and loss function of Masked Autoencoders (MAE).
Decoder
The input to the MAE decoder consists of all the tokens—that is:
Encoded visible patches, and
Mask tokens.
Similar to SimMIM, a shared, learnable mask token vector is used as a substitute for the missing (masked) patches in the input. The full set of tokens is then passed through a transformer network containing self-attention layers.
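As a quick illustration, the snippet below sketches how a single shared mask token could be broadcast to fill in the masked positions before the decoder runs. The shapes and variable names here are illustrative only and are not taken from the lesson's code:

import torch
import torch.nn as nn

# 12 x 12 = 144 patches for a 192 x 192 image with 16 x 16 patches;
# 28 patches remain visible at an 80% mask ratio
B, num_visible, num_total, D = 2, 28, 144, 512
visible_tokens = torch.randn(B, num_visible, D)   # encoded visible patches from the encoder
mask_token = nn.Parameter(torch.zeros(1, 1, D))   # one learnable vector shared by all masked positions

# Repeat the shared mask token once per masked patch and append it to the visible tokens
mask_tokens = mask_token.repeat(B, num_total - num_visible, 1)
full_tokens = torch.cat([visible_tokens, mask_tokens], dim=1)   # (B, num_total, D), still in shuffled order

After unshuffling the tokens back to their original patch order, this full token set is what the decoder's self-attention layers process.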
The goal of the MAE decoder is to reconstruct the image. Note that the MAE decoder is used only during the pre-training step (i.e., only the encoder is carried over to the transfer learning step). The design of the decoder is flexible: you can opt for a shallow decoder to keep the training overhead minimal.
The code below implements the decoder layer of MAEs.
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as T
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer
import warnings

warnings.simplefilter("ignore", UserWarning)


class MaskGenerator:
    def __init__(self, input_size=192, mask_patch_size=16, mask_ratio=0.8):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.mask_ratio = mask_ratio
        assert self.input_size % self.mask_patch_size == 0
        self.rand_size = self.input_size // self.mask_patch_size
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self, images):
        mask_ids, ids_keeps, masks = [], [], []
        for image in images:
            # Randomly split the patch indexes into masked and visible sets
            permutation = np.random.permutation(self.token_count)
            mask_idx, ids_keep = permutation[:self.mask_count], permutation[self.mask_count:]
            mask = np.ones(self.token_count, dtype=int)
            mask[mask_idx] = 0
            mask = mask.reshape((self.rand_size, self.rand_size))
            # Upsample the patch-level mask to pixel resolution
            mask = mask.repeat(self.mask_patch_size, axis=0).repeat(self.mask_patch_size, axis=1)
            mask_ids.append(mask_idx)
            ids_keeps.append(ids_keep)
            masks.append(mask)
        return mask_ids, ids_keeps, masks


class EncoderForMAE(VisionTransformer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert self.num_classes == 0

    def forward(self, x, mask_ids, ids_keep):
        x = self.patch_embed(x)
        B, L, D = x.shape
        x = x + self.pos_embed[:, 1:, :]
        # Keep only the visible (unmasked) patch embeddings
        x = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        cls_tokens = self.cls_token.expand(B, -1, -1)  # cls_tokens impl from Phil Wang
        cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x
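The DecoderForMAE class that the walkthrough below refers to is not reproduced in the snippet above. For reference, here is a minimal sketch of how such a decoder could be implemented, following the description below and the general MAE design; the constructor arguments, the token unshuffling details, and the assumption of 3-channel images are illustrative choices rather than the course's exact code:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import VisionTransformer


class DecoderForMAE(VisionTransformer):
    def __init__(self, encoder_embed_size=768, patch_size=16, **kwargs):
        super().__init__(patch_size=patch_size, **kwargs)
        # Project encoder features (encoder_embed_size) into the decoder width (self.embed_dim)
        self.decoder_embed = nn.Linear(encoder_embed_size, self.embed_dim, bias=True)
        # One shared, learnable token that stands in for every masked patch
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        # Prediction head: one output per pixel and channel of a patch (3-channel images assumed)
        self.decoder_pred = nn.Linear(self.embed_dim, patch_size ** 2 * 3, bias=True)
        trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x, mask_ids, ids_keep):
        # x: encoder output (cls token + visible patches)
        # mask_ids, ids_keep: long tensors of shape (B, num_masked) and (B, num_visible)
        x = self.decoder_embed(x)
        B, _, D = x.shape
        # Append one mask token per masked patch (the cls token stays at position 0)
        mask_tokens = self.mask_token.repeat(B, mask_ids.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        # Unshuffle the tokens back to their original patch order
        ids_all = torch.cat([ids_keep, mask_ids], dim=1)
        ids_restore = torch.argsort(ids_all, dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        # Add positional embeddings and run the decoder transformer blocks
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Predict the pixel values of every patch and drop the cls token
        x = self.decoder_pred(x)
        return x[:, 1:, :]

A call would then look like decoder(encoder_output, mask_ids, ids_keep), where mask_ids and ids_keep are the per-image index arrays returned by MaskGenerator, stacked into integer tensors.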
Lines 13–18: We define a class, DecoderForMAE, which inherits from the VisionTransformer class. It takes the encoder embedding size, encoder_embed_size, as input and initializes a projection layer, self.decoder_embed (which projects encoder features of size encoder_embed_size to the decoder input space of size self.embed_dim), the learnable mask token, self.mask_token, and the decoder prediction head, self.decoder_pred.
Lines 21–38: We define the forward() function, which takes the encoder features, x, the masked patch indexes, mask_ids, and the visible patch indexes, ids_keep, as inputs, and returns the reconstructed image patches.
Line 22: We project the encoder features, x, to the decoder input space using the linear projection layer, self.decoder_embed. This layer maps the encoder features (of size encoder_embed_size) to the decoder input space (of size self.embed_dim). Here, ...