Masked Autoencoders: Masking and Encoder
Learn how to implement the masking strategy and encoder layer of Masked Autoencoders (MAE).
Similar to SimMIM, a Masked Autoencoder (MAE) reconstructs randomly masked image patches in the image pixel space. It uses an asymmetric encoder-decoder design: the encoder sees only the visible patches (i.e., masked patches are dropped from its input), while a lightweight decoder takes the encoded visible patches together with mask tokens and reconstructs the original image. The figure below illustrates the idea.
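Alongside that figure, the asymmetric design can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the MAE reference implementation: the helper patchify, the single linear projections standing in for the Transformer encoder and decoder, the all-zeros mask token, and the embedding sizes (64 and 32) are hypothetical placeholders. What it shows is the bookkeeping: only the visible patches reach the encoder, and the decoder receives a full-length sequence in which a shared mask token fills every masked position.

```python
import numpy as np

def patchify(image, patch_size):
    """Hypothetical helper: split an (H, W, C) image into (num_patches, patch_size*patch_size*C)."""
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (gh, gw, p, p, c)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3)).astype(np.float32)
patches = patchify(image, 16)                     # (196, 768)
num_patches, patch_dim = patches.shape
embed_dim, decoder_dim, mask_ratio = 64, 32, 0.8  # assumed sizes for illustration

# Random masking: shuffle the patch indexes and keep the first ones as visible
ids_shuffle = rng.permutation(num_patches)
ids_restore = np.argsort(ids_shuffle)             # inverse permutation, used to unshuffle
num_keep = int(num_patches * (1 - mask_ratio))    # 39 of 196 patches stay visible
ids_keep = ids_shuffle[:num_keep]

# Stand-in encoder (hypothetical): one linear projection applied ONLY to visible patches
W_enc = rng.standard_normal((patch_dim, embed_dim)).astype(np.float32) * 0.02
latent = patches[ids_keep] @ W_enc                # (39, 64)

# Decoder input: project the latents, append one shared mask token per masked patch,
# then unshuffle so every token sits at its original patch position
W_dec = rng.standard_normal((embed_dim, decoder_dim)).astype(np.float32) * 0.02
mask_token = np.zeros((1, decoder_dim), dtype=np.float32)
tokens = np.concatenate(
    [latent @ W_dec, np.repeat(mask_token, num_patches - num_keep, axis=0)], axis=0
)
decoder_input = tokens[ids_restore]               # (196, 32)
```

In an actual MAE, the mask token is a learned parameter and positional embeddings are added to the encoder and decoder inputs; both are omitted here to keep the shapes easy to follow.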
Masking strategy
Like SimMIM, MAE masks the input patches using random sampling. However, it uses a high masking ratio (i.e., the fraction of patches removed), which makes reconstruction challenging: with most patches gone, the task cannot be solved simply by extrapolating from neighboring visible patches. The code snippet below masks the input image patches with a masking ratio of 80 percent.
Note: An 80 percent masking ratio means that 80 percent of the image patches are masked.
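The patch arithmetic behind the masking ratio is worth checking by hand. The sketch below uses a hypothetical helper, mask_counts, that mirrors the computation in MaskGenerator.__init__ (patches per side, total patches, and the masked count rounded up with ceil) for the default input size of 192 and the 224 pixels used in the example.

```python
import math

def mask_counts(input_size, patch_size, mask_ratio):
    """Hypothetical helper: return (total, masked, visible) patch counts, rounding masked up."""
    assert input_size % patch_size == 0, "image side must be divisible by the patch size"
    side = input_size // patch_size          # patches per side
    total = side * side                      # total number of patches
    masked = math.ceil(total * mask_ratio)   # same rounding as MaskGenerator
    return total, masked, total - masked

for size in (192, 224):
    total, masked, visible = mask_counts(size, 16, 0.8)
    print(f"{size}px: {total} patches, {masked} masked, {visible} visible")
# 192px: 144 patches, 116 masked, 28 visible
# 224px: 196 patches, 157 masked, 39 visible
```

Because the MAE encoder only processes the visible patches, its input sequence shrinks from 196 to 39 tokens in the 224-pixel case.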
```python
import os
import numpy as np
from PIL import Image

class MaskGenerator:
    """Generates random patch-level masks for a batch of images."""
    def __init__(self, input_size=192, mask_patch_size=16, mask_ratio=0.8):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.mask_ratio = mask_ratio
        # The image side length must be divisible by the patch size
        assert self.input_size % self.mask_patch_size == 0
        self.rand_size = self.input_size // self.mask_patch_size  # patches per side
        self.token_count = self.rand_size ** 2  # total number of patches
        # Number of patches to mask, rounded up
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self, images):
        mask_ids, ids_keeps, masks = [], [], []
        for _ in images:
            permutation = np.random.permutation(self.token_count)  # random order of patch indexes
            mask_idx, ids_keep = permutation[:self.mask_count], permutation[self.mask_count:]
            mask = np.ones(self.token_count, dtype=int)  # 1 = visible patch, 0 = masked patch
            mask[mask_idx] = 0
            mask = mask.reshape(self.rand_size, self.rand_size).repeat(self.mask_patch_size, axis=0).repeat(self.mask_patch_size, axis=1)
            mask_ids.append(mask_idx)
            ids_keeps.append(ids_keep)
            masks.append(mask)
        return mask_ids, ids_keeps, masks

# Load two example images as RGB so they can be stacked into one array
images = [Image.open("n02107683_Bernese_mountain_dog.jpeg").convert("RGB").resize((224, 224)),
          Image.open("cat.jpg").convert("RGB").resize((224, 224))]
os.makedirs("./output", exist_ok=True)
images[0].save("./output/image.png")
images = np.stack([np.array(image) for image in images], 0)  # shape (2, 224, 224, 3)

generator = MaskGenerator(input_size=224)  # 14 x 14 = 196 patches, 157 of them masked
mask_ids, ids_keeps, masks = generator(images)

# Save the first mask (white = visible, black = masked)
mask = Image.fromarray((255 * masks[0]).astype(np.uint8))
mask.save("./output/mask.png")

# Apply the mask to the first image by zeroing out the masked pixels
masked_image = np.expand_dims(masks[0], -1) * images[0]
masked_image = Image.fromarray(masked_image.astype(np.uint8))
masked_image.save("./output/masked_image.png")
```
Line 5: We define a class, MaskGenerator, which takes a set of images and generates random masks for them.
Lines 7–16: We implement the __init__ function, which takes the image's input_size, mask_patch_size, and masking ratio mask_ratio as input parameters and calculates the total number of patches, self.token_count, and the number of patches to mask, self.mask_count.
Lines 18–29: We implement the __call__ function, which takes a set of images as input and generates their masks.
Lines 22–25: We first sample the random patch indexes, mask_idx, which will be masked. The rest of the patch ...