Model Training

Learn how to train a model for the image creation process.

Model training is a crucial step in the development of diffusion models: every powerful generative model used for denoising and image synthesis goes through it. Here, we’ll explore the key steps and concepts involved in training a diffusion model.

Training a neural network

Let’s discuss how to train this U-Net neural network to predict noise. The network’s goal is to predict the noise in an image and learn its distribution, which also means recognizing what isn’t noise. To train it, we take a character image from our training data, add noise to it, and feed the noisy image to the network. We then compare the predicted noise against the noise we actually added, and that comparison gives us the loss. Through backpropagation, the network gradually learns to predict that noise better.
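To make this concrete, here is a minimal sketch of a single training step. It assumes a DDPM-style linear noise schedule (β from 1e-4 to 0.02 over T = 1000 steps), compares predicted and true noise with mean squared error (the usual DDPM choice), and replaces the U-Net with a hypothetical two-layer `TinyNoisePredictor`. These choices and the random placeholder image are illustrative assumptions, not this lesson’s implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

T = 1000                                        # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)           # linear noise schedule (assumed DDPM values)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of (1 - beta)


class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in for the U-Net: predicts the noise in an image."""

    def __init__(self):
        super().__init__()
        # 2 input channels: the noisy image plus a constant channel encoding t
        self.net = nn.Sequential(
            nn.Conv2d(2, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 1, 3, padding=1),
        )

    def forward(self, x, t):
        # Broadcast the normalized time step into an image-sized channel
        t_channel = (t.float() / T).view(-1, 1, 1, 1).expand_as(x)
        return self.net(torch.cat([x, t_channel], dim=1))


model = TinyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.rand(1, 1, 28, 28) * 2 - 1  # placeholder "character" image in [-1, 1]
t = torch.tensor([500])                # one chosen time step
eps = torch.randn_like(x0)             # the actual noise we add

# Forward diffusion: build the noisy image at step t
a_bar = alpha_bars[t].view(-1, 1, 1, 1)
x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps

# The network predicts the noise; the loss compares it with the true noise
eps_pred = model(x_t, t)
loss = F.mse_loss(eps_pred, eps)

optimizer.zero_grad()
loss.backward()   # backpropagation computes gradients of the loss
optimizer.step()  # update the weights so the next prediction is closer
print(f"loss: {loss.item():.4f}")
```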

In principle, we could step through the time steps in order, as we do during sampling, and give the network each noise level in turn. In practice, we don’t want the network to keep looking at the same character during training. Training is more stable, and the noise levels are covered more uniformly, when the network sees different characters across an epoch (a single pass of the learning algorithm through the entire training dataset). So for each image, we randomly sample a time step, get the noise level appropriate to that time step, add that noise to the image, and have the network predict it. We then take the next image in our training data, sample another random time step, add the corresponding noise to that character image, and again have the network predict the noise that was added. This results in a much more stable training scheme.
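The random time step sampling described above can be sketched as follows, again assuming a linear schedule with T = 1000. The hypothetical helper `noisy_batch` draws an independent time step for every image in a batch, so a single batch already mixes many noise levels.

```python
import torch

torch.manual_seed(0)

T = 1000                                        # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)           # linear schedule (assumed)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # noise level for each step


def noisy_batch(x0: torch.Tensor):
    """Give every image in the batch its own random time step and matching noise."""
    n = x0.shape[0]
    t = torch.randint(0, T, (n,))           # one random step per image
    eps = torch.randn_like(x0)              # fresh Gaussian noise
    a_bar = alpha_bars[t].view(n, 1, 1, 1)  # per-image noise level
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    return x_t, t, eps


batch = torch.rand(8, 1, 28, 28) * 2 - 1    # placeholder images in [-1, 1]
x_t, t, eps = noisy_batch(batch)
print(t.tolist())  # a mix of time steps across the batch
```

In a training loop, the network would then be asked to predict `eps` from `(x_t, t)` for the whole batch at once, so each weight update is informed by many different noise levels.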

Variation of samples across epochs

Setting up an environment

The first step in the model definition is to set up the necessary environment and define a neural network model. Let’s go through each step involved in this process.

Importing libraries

The first step in model training is to import libraries that will be used throughout the process.

```python
import random
import imageio
import numpy as np
from argparse import ArgumentParser
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST
# Setting reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Definitions
STORE_PATH_MNIST = "ddpm_model_mnist.pt"
```

Let’s discuss the breakdown of the imported libraries.

  • Lines 1–13: Here, we import the libraries and modules needed to run the script.

  • Lines 15–18: We set a seed value so that random number generation in Python, NumPy, and PyTorch is reproducible.

  • Line 20: We store, as a constant, the file path for the trained model on the MNIST dataset.
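A quick way to see what the seeding lines buy us: in the sketch below, a hypothetical `draw` helper reseeds all three generators and then draws one number from each. Two calls with the same seed return identical values.

```python
import random

import numpy as np
import torch


def draw(seed: int):
    """Seed Python, NumPy, and PyTorch, then draw one number from each."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), np.random.rand(), torch.rand(1).item()


first = draw(0)
second = draw(0)
print(first == second)  # True: same seed, same numbers
```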

Setting hyperparameters

Here are a few options we should set before training our model:

```python
no_train = False
fashion = True
batch_size = 128
n_epochs = 20
lr = 0.001
store_path = "ddpm_mnist.pt"
```

Let’s look closer at these lines of code:

  • Line 1: We specify whether to skip the training loop and use a pretrained model instead. If we haven’t already trained a model with this notebook, we keep this as False. To use a pretrained model, we first upload it to the Jupyter filesystem.

  • Line 2: The fashion flag indicates whether to use the Fashion-MNIST dataset instead of MNIST.

  • Lines 3–5: batch_size, n_epochs, and lr are our typical training hyperparameters: the number of images per batch, the number of passes over the dataset, and the optimizer’s learning rate.

  • Line 6: We set the name of the file in which the trained model will be saved.
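Here is a sketch of how these options are typically consumed: batch_size goes to the DataLoader, lr to the Adam optimizer, n_epochs sets the outer loop, no_train skips the loop, and store_path is where the weights are saved. It assumes a stand-in dataset of random tensors and a one-layer model in place of MNIST and the U-Net; the placeholder objective, the temporary-directory path, and the reduced n_epochs are assumptions made to keep the sketch fast and self-contained.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

# Options from the lesson (n_epochs lowered for a quick run)
no_train = False
batch_size = 128
n_epochs = 2  # the lesson uses 20
lr = 0.001
store_path = os.path.join(tempfile.gettempdir(), "ddpm_mnist.pt")  # temp dir for this sketch

torch.manual_seed(0)
# Hypothetical stand-ins: random "images" and a one-layer model
data = TensorDataset(torch.randn(512, 1, 28, 28))
loader = DataLoader(data, batch_size=batch_size, shuffle=True)
model = nn.Conv2d(1, 1, 3, padding=1)

if not no_train:
    optimizer = Adam(model.parameters(), lr=lr)
    for epoch in range(n_epochs):
        for (x0,) in loader:
            eps = torch.randn_like(x0)
            # Placeholder objective: predict the noise added to x0 (no schedule here)
            loss = nn.functional.mse_loss(model(x0 + eps), eps)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), store_path)  # save the trained weights

# With no_train = True, we would skip the loop and only load the saved weights
model.load_state_dict(torch.load(store_path, map_location="cpu"))
```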

Displaying images

The following two helper functions display images.

```python
def show_images(images, title=""):
    """Shows the provided images as subpictures in a square"""
    # Converting images to CPU numpy arrays
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    # Defining rows and columns (ceiling division so every image gets a cell)
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = -(-len(images) // rows)
    # Populating figure with subplots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)
            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
            idx += 1
    fig.suptitle(title, fontsize=30)
    # Showing the figure
    plt.show()
def show_first_batch(loader):
    for batch in loader:
        show_images(batch[0], "Images in the first batch")
        break
```

Let’s look at the breakdown of the two functions above:

  • Lines 1–2: The show_images function displays a collection of images as subplots arranged in a roughly square grid within a single figure, under a custom title. ...