VAEs in Action
Master VAEs for image generation, including data preprocessing, training techniques, architecture design, training stages, and image synthesis through encoding and sampling.
Overview
In this lesson, we’ll delve into VAEs and their application to the MNIST dataset. Like traditional autoencoders, VAEs consist of an encoder and a decoder, but they go beyond simple image reconstruction. VAEs allow us to generate new data points by sampling from the learned latent space. Our focus will be on understanding the fascinating concept of generating new images from latent representations.
Loading and preprocessing the data
In this step, we load the MNIST dataset and prepare it for training the VAE. The dataset is converted to floating-point format, and the pixel values are normalized to the range [0, 1]. Additionally, we perform a visual inspection to understand the data distribution and validate the preprocessing. This ensures our VAE model can work effectively with the data and facilitates any necessary adjustments in the preprocessing pipeline.
Unique training approach in VAEs
In the context of VAEs, the conventional division of data into training and test sets is not required. Unlike classification tasks, where we need to evaluate the model’s performance on unseen data, VAEs aim to learn a latent representation of the entire dataset. All data is used for training the model, and new samples are generated from the learned latent space.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as nnF

# Load and preprocess the MNIST dataset
mnist_train = torchvision.datasets.MNIST(root='./data', train=True, download=True)
x_train = mnist_train.data
y_train = mnist_train.targets
x_train = x_train.float() / 255.0

# Visualize the dataset
import matplotlib.pyplot as plt
import numpy as np

num_rows = 2
num_columns = 5
fig, axes = plt.subplots(num_rows, num_columns, figsize=(4, 2))

for category in range(10):
    category_indices = np.where(y_train == category)[0]
    random_index = np.random.choice(category_indices)
    image_np = np.array(x_train[random_index])
    row = category // num_columns
    col = category % num_columns
    ax = axes[row, col]
    ax.imshow(image_np, cmap='gray')
    ax.axis('off')

plt.tight_layout()
plt.show()
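Before moving on, a quick sanity check can confirm the shapes and value range produced by the preprocessing. The snippet below is illustrative and not part of the lesson's code:

# Illustrative sanity check of the preprocessed data (assumed, not from the lesson)
print(x_train.shape)                               # torch.Size([60000, 28, 28])
print(x_train.min().item(), x_train.max().item())  # 0.0 1.0
print(y_train.shape)                               # torch.Size([60000])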
Defining the architecture
The VAE’s architecture has three main components.
Encoder
The encoder class is a crucial part of the VAE architecture, responsible for compressing the input MNIST image into a lower-dimensional latent space. It consists of three linear layers: self.input, self.hidden_mean, and self.hidden_std. The encoder processes the input image through these layers and outputs the mean (mean) and standard deviation (std) vectors of the latent Gaussian distribution. These vectors represent the parameters ...
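To make this description concrete, here is a minimal sketch of how such an encoder could be written, reusing the nn and nnF imports from the code above. The layer sizes (a flattened 784-pixel input, a 256-unit hidden layer, and a 20-dimensional latent space) and the ReLU activation are assumptions for illustration and may differ from the implementation used later in the lesson:

# Minimal encoder sketch (dimensions and activation are illustrative assumptions)
class Encoder(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=20):
        super().__init__()
        self.input = nn.Linear(input_dim, hidden_dim)         # input -> hidden representation
        self.hidden_mean = nn.Linear(hidden_dim, latent_dim)  # hidden -> mean of the latent Gaussian
        self.hidden_std = nn.Linear(hidden_dim, latent_dim)   # hidden -> std of the latent Gaussian

    def forward(self, x):
        h = nnF.relu(self.input(x))   # assumes x is a flattened 28x28 image
        mean = self.hidden_mean(h)
        std = self.hidden_std(h)
        return mean, std

Note that many VAE implementations predict the log-variance rather than the standard deviation for numerical stability; since the text above refers to a std output, the sketch follows that convention.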