Masked Autoencoders: Masking and Encoder

Learn how to implement the masking strategy and encoder layer of Masked Autoencoders (MAE).

Similar to SimMIM, a Masked Autoencoder (MAE) reconstructs randomly masked image patches in pixel space. It uses an asymmetric encoder-decoder design. The encoder sees only the visible patches, because masked patches are dropped from its input. A lightweight decoder then reconstructs the full image from the encoded visible patches together with mask tokens that stand in for the masked ones. The figure below illustrates the idea.
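The following minimal NumPy sketch traces the asymmetric data flow. It assumes a 224×224 RGB image, 16×16 patches, and an 80 percent masking ratio. The `patchify` helper, the random projection matrices `W_enc` and `W_dec`, and the all-zero `mask_token` are hypothetical stand-ins for the real Transformer encoder, the decoder, and the learned mask token. Only the tensor shapes and the indexing are meant to illustrate the idea.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an (H, W, C) image into a (num_patches, patch_size * patch_size * C) array."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (gh, p, gw, p, c) -> (gh, gw, p, p, c) -> (gh * gw, p * p * c), patches in row-major order
    patches = image.reshape(gh, patch_size, gw, patch_size, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3), dtype=np.float32)
patches = patchify(image)  # (196, 768)

# Random masking with an 80% ratio: the first mask_count shuffled indexes are masked
num_patches = patches.shape[0]
mask_count = int(np.ceil(num_patches * 0.8))  # 157
perm = rng.permutation(num_patches)
ids_mask, ids_keep = perm[:mask_count], perm[mask_count:]  # 157 masked, 39 visible

# Hypothetical "encoder": a random linear projection applied to the visible patches only
embed_dim = 128
W_enc = rng.standard_normal((patches.shape[1], embed_dim)).astype(np.float32) * 0.02
encoded = patches[ids_keep] @ W_enc  # (39, 128)

# Decoder input: a full-length sequence with a mask token at every masked position
mask_token = np.zeros(embed_dim, dtype=np.float32)  # placeholder for a learned vector
decoder_input = np.tile(mask_token, (num_patches, 1))  # (196, 128)
decoder_input[ids_keep] = encoded  # visible tokens return to their original positions

# Hypothetical lightweight "decoder": project every token back to pixel space
W_dec = rng.standard_normal((embed_dim, patches.shape[1])).astype(np.float32) * 0.02
pred = decoder_input @ W_dec  # (196, 768)

# Reconstruction loss (mean squared error) computed on the masked patches only
loss = np.mean((pred[ids_mask] - patches[ids_mask]) ** 2)
```

The encoder processes only 39 of the 196 patches, which keeps it cheap to run. As in the MAE paper, the reconstruction error is measured only on the masked patches.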

Masking strategy

Like SimMIM, MAE masks the input patches by random sampling. However, it uses a high masking ratio (the fraction of patches removed). At that ratio, the missing content cannot be recovered just by interpolating from neighboring visible patches, which makes reconstruction a challenging task. The code snippet below masks the 16×16 patches of an input image using an 80 percent masking ratio.

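Note: An 80 percent masking ratio means that 80 percent of the image patches are masked. For example, a 224×224 image split into 16×16 patches has 14 × 14 = 196 patches. Of these, ⌈0.8 × 196⌉ = 157 are masked and 39 remain visible.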

import torch
import numpy as np
from PIL import Image

class MaskGenerator:
    """Generates random patch-level masks (1 = visible, 0 = masked) for a set of images."""
    def __init__(self, input_size=192, mask_patch_size=16, mask_ratio=0.8):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.mask_ratio = mask_ratio
        # The image side length must be divisible by the patch size
        assert self.input_size % self.mask_patch_size == 0
        self.rand_size = self.input_size // self.mask_patch_size  # patches per side
        self.token_count = self.rand_size ** 2  # total number of patches
        # Number of patches to mask, rounded up
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self, images):
        mask_ids, ids_keeps, masks = [], [], []
        for _ in images:
            # Shuffle the patch indexes: the first mask_count are masked, the rest are kept
            permutation = np.random.permutation(self.token_count)
            mask_idx, ids_keep = permutation[:self.mask_count], permutation[self.mask_count:]
            mask = np.ones(self.token_count, dtype=int)  # 1 = visible patch
            mask[mask_idx] = 0  # 0 = masked patch
            # Upsample the patch-level mask to pixel resolution
            mask = mask.reshape((self.rand_size, self.rand_size))
            mask = mask.repeat(self.mask_patch_size, axis=0).repeat(self.mask_patch_size, axis=1)
            mask_ids.append(mask_idx)
            ids_keeps.append(ids_keep)
            masks.append(mask)
        return mask_ids, ids_keeps, masks

# Load two example images as RGB (so both have 3 channels) and resize them to 224x224
images = [Image.open("n02107683_Bernese_mountain_dog.jpeg").convert("RGB").resize((224, 224)),
          Image.open("cat.jpg").convert("RGB").resize((224, 224))]
images[0].save("./output/image.png")
images = np.stack([np.array(image) for image in images], 0)  # (2, 224, 224, 3)

# 224 / 16 = 14 patches per side, i.e., 196 patches per image
generator = MaskGenerator(input_size=224)
mask_ids, ids_keeps, masks = generator(images)

# Save the first mask as a black-and-white image (white = visible, black = masked)
mask = Image.fromarray((255 * masks[0]).astype(np.uint8))
mask.save("./output/mask.png")

# Zero out the masked pixels of the first image and save the result
masked_image = np.expand_dims(masks[0], -1) * images[0]
masked_image = Image.fromarray(masked_image.astype(np.uint8))
masked_image.save("./output/masked_image.png")
  1. Line 5: We define a class, MaskGenerator, which takes a set of images and generates random masks for them.

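  2. Lines 7–16: We implement the __init__ function. It takes the image's input_size, the mask_patch_size, and the masking ratio mask_ratio as input parameters and checks that the image divides evenly into patches. It then calculates the number of patches per side self.rand_size, the total number of patches self.token_count, and the number of patches to mask self.mask_count.

  3. Lines 18–32: We implement the __call__ function. It takes a set of images as input and returns three lists: the masked patch indexes, the kept patch indexes, and a pixel-level binary mask for each image.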

  4. Lines 22–25: We first sample the random patch indexes, mask_idx, which will be masked. The rest of the patch ...
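The per-image loop in `__call__` can also be vectorized over a batch. The sketch below uses a hypothetical `random_masking` helper written in NumPy. It draws uniform noise for every patch and sorts it. The `argsort` of the noise is a random permutation, and the `argsort` of that permutation (`ids_restore`) is its inverse, which a decoder could later use to put tokens back in their original order. The official MAE code uses a comparable noise-and-argsort trick in PyTorch, but with the opposite convention (0 = keep, 1 = remove). This sketch keeps this lesson's convention of 1 = visible and 0 = masked.

```python
import numpy as np

def random_masking(batch_size, num_patches, mask_ratio, rng):
    """Hypothetical vectorized masking helper.

    Returns ids_keep (B, num_keep), ids_restore (B, L), and mask (B, L) with 1 = visible, 0 = masked.
    """
    num_keep = num_patches - int(np.ceil(num_patches * mask_ratio))
    noise = rng.random((batch_size, num_patches))
    ids_shuffle = np.argsort(noise, axis=1)        # a random permutation per image
    ids_restore = np.argsort(ids_shuffle, axis=1)  # the inverse permutation
    ids_keep = ids_shuffle[:, :num_keep]           # the first num_keep shuffled patches stay visible
    # Build the mask in shuffled order, then unshuffle it back to the original patch order
    mask_shuffled = np.zeros((batch_size, num_patches), dtype=int)
    mask_shuffled[:, :num_keep] = 1
    mask = np.take_along_axis(mask_shuffled, ids_restore, axis=1)
    return ids_keep, ids_restore, mask

rng = np.random.default_rng(0)
ids_keep, ids_restore, mask = random_masking(2, 196, 0.8, rng)

# Gather the visible patch tokens of every image in one call (this is the encoder's input)
tokens = rng.random((2, 196, 768))
visible = np.take_along_axis(tokens, ids_keep[..., None], axis=1)  # (2, 39, 768)
```

Sampling through `argsort` avoids a Python loop over the batch, and `np.take_along_axis` broadcasts the `(B, num_keep, 1)` index array across the embedding dimension.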