...
Generating Images from Labels with the CGAN
Create a model that generates images from labels with a CGAN, using the MNIST dataset.
We have already defined the architecture of both the generator and discriminator networks of the CGAN. Now, let’s write the code for model training. To make the results easy to reproduce, we will train on MNIST and see how the CGAN performs at image generation. The goal is that, once the model is trained, it generates an image of whichever digit we ask for, with plenty of variety between samples of the same digit.
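Concretely, "telling" the generator which digit to draw means pairing each noise vector with a class label. The following minimal PyTorch sketch builds such a request batch: eight samples of each digit, each with its own random latent vector. The values `latent_dim = 100` and `samples_per_digit = 8` are illustrative assumptions, not settings taken from this course.

```python
import torch

# Illustrative values (assumptions, not taken from the course code).
classes = 10            # MNIST digits 0-9
latent_dim = 100        # length of each noise vector
samples_per_digit = 8   # how many variants of each digit to request

# Requested labels: eight 0s, then eight 1s, ..., then eight 9s.
labels = torch.arange(classes).repeat_interleave(samples_per_digit)

# One independent noise vector per requested image; different z with the
# same label is what should yield different-looking versions of a digit.
z = torch.randn(labels.size(0), latent_dim)

print(labels.shape, z.shape)  # torch.Size([80]) torch.Size([80, 100])
```

A trained generator would then be called with both tensors, for example `netG(z, labels)`; that argument order is an assumption here, since the generator's forward signature is defined elsewhere in the course.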
One-stop model training API
First, let’s create a new `Model` class that serves as a wrapper for different models and provides a one-stop training API. Create a new file named `build_gan.py` and import the necessary modules:
```python
import os

import numpy as np
import torch
import torchvision.utils as vutils

from cgan import Generator as cganG
from cgan import Discriminator as cganD
```
Then, let's create the `Model` class. In this class, we will initialize the `Generator` and `Discriminator` modules and provide `train` and `eval` methods, so that users can simply call `Model.train()` or `Model.eval()` elsewhere to run model training or evaluation.
```python
class Model(object):
    def __init__(self, name, device, data_loader, classes, channels, img_size, latent_dim):
        # Store the configuration shared by all wrapped models
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim

        # Build the generator for the requested model type and move it to the device
        if self.name == 'cgan':
            self.netG = cganG(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netG.to(self.device)

        # Build the discriminator the same way
        if self.name == 'cgan':
            self.netD = cganD(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netD.to(self.device)

        # Optimizers are placeholders here and are assigned later
        self.optim_G = None
        self.optim_D = None
```
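The constructor moves both networks to `self.device`. The following minimal sketch shows one common way such a `device` argument is chosen and what `.to()` does, using a small `nn.Linear` as a hypothetical stand-in for the course's networks. It also shows that `torch.nn.Module` has its own `train()`/`eval()` methods, which only toggle a mode flag. Since `Model` here is a plain `object`, its `train` and `eval` are separate, user-defined methods; whether they also switch `netG`/`netD` between modes is an assumption about code not shown in this section.

```python
import torch
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tiny hypothetical stand-in network (not the course's Generator).
net = nn.Linear(8, 4)
net.to(device)  # moves the module's parameters to the device in place

print(next(net.parameters()).device.type)  # 'cuda' or 'cpu'

# nn.Module.train()/eval() only set the `training` flag, which changes the
# behavior of layers such as Dropout and BatchNorm; they run no training loop.
net.eval()
print(net.training)  # False
net.train()
print(net.training)  # True
```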
Here, the generator network, `netG`, and the discriminator network, `netD`, are initialized based on the number of classes (`classes`), the number of image channels (`channels`), the image size (`img_size`), and the length of the latent vector (`latent_dim`). These arguments will be given later; for now, let's assume their values are already known. Since we need to ...
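For MNIST, plausible values for these arguments can be worked out directly. The sketch below uses assumed settings (the course may, for example, resize images to a larger `img_size` or pick a different `latent_dim`) and computes the image shape, the number of pixel values per image, and the generator input length under one common CGAN design, in which the label becomes a vector of length `classes` that is concatenated with the noise. That design is an assumption here, not a description of the course's `Generator`.

```python
from math import prod

# Assumed MNIST settings (illustrative only).
classes = 10      # digits 0-9
channels = 1      # MNIST images are grayscale
img_size = 28     # native MNIST resolution is 28x28
latent_dim = 100  # a common length for the noise vector

img_shape = (channels, img_size, img_size)
num_pixels = prod(img_shape)  # values a fully connected generator outputs per image

# Assumed design: label vector of length `classes` concatenated with the noise,
# so the generator's first layer receives latent_dim + classes inputs.
gen_input_len = latent_dim + classes

print(img_shape, num_pixels, gen_input_len)  # (1, 28, 28) 784 110
```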