Model Training
Learn how to train a model for the image creation process.
Model training is a crucial step in developing diffusion models: the generative models used for denoising and image synthesis all go through it. Here, we’ll explore the key steps and concepts used to train a diffusion model.
Let’s discuss how to train the U-Net neural network to predict noise. The goal is for the network to learn the distribution of the noise on an image, which also means learning what isn’t noise. To do that, we take a character from our training data, add noise to it, and feed the noisy image to the network. We then compare the noise the network predicts against the noise we actually added to that image, and that comparison gives the loss. Through backpropagation, the network learns to predict the noise better.
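The training step described above can be sketched as follows. This is only a minimal illustration, not the lesson's actual model: TinyNoisePredictor is a hypothetical stand-in for the U-Net, the linear noise schedule (n_steps, betas, alpha_bars) is an assumed DDPM-style choice, and the input batch is random data rather than real training images.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the U-Net: it maps a noisy image and a timestep
# to a noise estimate with the same shape as the image.
class TinyNoisePredictor(nn.Module):
    def __init__(self, n_steps):
        super().__init__()
        self.time_embed = nn.Embedding(n_steps, 28 * 28)
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 28 * 28)
        )

    def forward(self, x, t):
        flat = x.flatten(1) + self.time_embed(t)
        return self.net(flat).view_as(x)

# Assumed DDPM-style linear noise schedule; the values are illustrative.
n_steps = 1000
betas = torch.linspace(1e-4, 0.02, n_steps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

model = TinyNoisePredictor(n_steps)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
mse = nn.MSELoss()

def training_step(x0):
    """One step: add noise, predict it, compare, and backpropagate."""
    n = x0.shape[0]
    t = torch.randint(0, n_steps, (n,))        # random noise level per image
    eps = torch.randn_like(x0)                 # the actual noise we add
    a_bar = alpha_bars[t].view(n, 1, 1, 1)
    noisy = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    loss = mse(model(noisy, t), eps)           # predicted vs. actual noise
    optimizer.zero_grad()
    loss.backward()                            # backpropagation
    optimizer.step()
    return loss.item()

# Random stand-in for a batch of 28x28 grayscale images scaled to [-1, 1].
x0 = torch.rand(8, 1, 28, 28) * 2 - 1
losses = [training_step(x0) for _ in range(5)]
print(losses)
```

Each call draws a fresh noise level and fresh noise, so the loss values fluctuate from step to step rather than decreasing monotonically.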
We could simply step through the timesteps in order, as in sampling, and give the network each noise level in turn. Realistically, though, we don’t want the neural network to always look at the same character during training. It’s more stable if it looks at different characters across an
Setting up an environment
The first step is to set up the necessary environment and define a neural network model. Let’s discuss the steps involved in this process.
Importing libraries
The first step in model training is to import libraries that will be used throughout the process.
import random
import imageio
import numpy as np
from argparse import ArgumentParser

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

import einops
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.datasets.mnist import MNIST

# Setting reproducibility
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Definitions
STORE_PATH_MNIST = "ddpm_model_mnist.pt"
Let’s discuss the breakdown of the imported libraries.
Lines 1–16: Here, we import the various libraries and modules necessary for executing the script.
Lines 19–22: We initialize a seed value for reproducibility across random number generation in Python, NumPy, and PyTorch.
Line 25: We define a constant that holds the file path of the trained model for the MNIST dataset.
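To see what seeding buys us, here is a small sketch. The helper names set_seed and draw are hypothetical and are not part of the lesson's code. It seeds all three random number generators, draws one value from each, reseeds, and draws again.

```python
import random

import numpy as np
import torch

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def draw():
    """Draw one random number from each of the three generators."""
    return random.random(), float(np.random.rand()), torch.rand(1).item()

set_seed(0)
first = draw()
set_seed(0)
second = draw()

print(first == second)  # True: the same seed reproduces the same draws
```

Without the second call to set_seed, the two draws would differ, which is why the seed is fixed once at the top of the script.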
Setting hyperparameters
Here are a few options we should set before training our model:
no_train = False
fashion = True
batch_size = 128
n_epochs = 20
lr = 0.001
store_path = "ddpm_mnist.pt"
Let’s look closer at these lines of code:
Line 1: We specify whether we want to skip the training loop and just use a pretrained model. If we haven’t trained a model already using this notebook, we keep this as False. If we want to use a pretrained model, we load it into the Jupyter filesystem.
Lines 3–5: batch_size, n_epochs, and lr are our typical training hyperparameters.
Line 6: We set the name of the file in which the trained model will be stored.
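As a rough worked example of what these hyperparameters imply, the sketch below counts optimizer steps. It assumes training runs over the full 60,000-image MNIST training split and that the last, smaller batch is kept (the DataLoader default of drop_last=False). Neither assumption is stated in the lesson.

```python
import math

batch_size = 128
n_epochs = 20
lr = 0.001

# Assumption: the standard MNIST training split of 60,000 images is used.
n_train_images = 60_000

# With the final partial batch kept, each epoch needs ceil(N / batch_size) steps.
steps_per_epoch = math.ceil(n_train_images / batch_size)
total_steps = steps_per_epoch * n_epochs

print(steps_per_epoch)  # 469
print(total_steps)      # 9380
```

Doubling batch_size would roughly halve the number of steps per epoch, so the learning rate and the number of epochs are usually tuned together with it.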
Displaying images
The following two utility functions are used to display images.
def show_images(images, title=""):
    """Shows the provided images as subpictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is torch.Tensor:
        images = images.detach().cpu().numpy()

    # Defining number of rows and columns
    fig = plt.figure(figsize=(8, 8))
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)

    # Populating figure with subplots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()


def show_first_batch(loader):
    for batch in loader:
        show_images(batch[0], "Images in the first batch")
        break
Let’s look at the breakdown of the two functions above:
Lines 1–2: The show_images function displays a collection of images as subplots in a roughly square grid within a single figure, under a custom title. ...
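The grid arithmetic inside show_images can be checked in isolation. The helper name grid_shape below is hypothetical. It reproduces only the rows and cols computation from the function above, without any plotting.

```python
def grid_shape(n_images):
    """Rows and columns computed the same way as in show_images."""
    rows = int(n_images ** (1 / 2))
    cols = round(n_images / rows)
    return rows, cols

for n in (16, 128, 5):
    rows, cols = grid_shape(n)
    print(n, (rows, cols), "fits" if rows * cols >= n else "does not fit")
```

For 16 images, the grid is 4 by 4. For a batch of 128, it is 11 by 12, which leaves a few empty cells. For 5 images, the result is 2 by 2, because Python rounds 2.5 to the nearest even integer, so one image would not be shown. Rounding cols up with math.ceil would be one way to avoid that edge case.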