...

Generating Images from Labels with the CGAN

Create a model that generates images from labels with a CGAN, using the MNIST dataset.

We have already defined the architectures of the CGAN's generator and discriminator networks. Now, let's write the code for model training. To make the results easy to reproduce, we will use MNIST as the training set and see how the CGAN performs at image generation. Our goal is that, once the model is trained, it can generate an image of whichever digit we ask for, with plenty of variety among the generated samples.
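To make "generate the digit we ask for" concrete, here is a minimal sketch of how a CGAN typically conditions its generator: a random noise vector is concatenated with a one-hot encoding of the requested label. It uses NumPy only so it runs without PyTorch. The MNIST figures (10 classes, 1 grayscale channel, 28×28 pixels) are properties of the dataset; LATENT_DIM = 100 and the helper names one_hot and generator_input are assumptions for illustration, and the cgan module defined earlier may combine label and noise differently.

```python
import numpy as np

# MNIST facts: 10 digit classes, single-channel (grayscale) 28x28 images.
CLASSES, CHANNELS, IMG_SIZE = 10, 1, 28
LATENT_DIM = 100  # assumption: a common choice, not taken from this course


def one_hot(labels, num_classes):
    """Turn integer labels of shape (batch,) into one-hot rows (batch, num_classes)."""
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out


def generator_input(labels, rng):
    """Hypothetical helper: concatenate Gaussian noise with the one-hot label."""
    z = rng.standard_normal((len(labels), LATENT_DIM)).astype(np.float32)
    return np.concatenate([z, one_hot(labels, CLASSES)], axis=1)


rng = np.random.default_rng(0)
labels = np.array([3, 3, 7])           # ask for two 3s and one 7
x = generator_input(labels, rng)
print(x.shape)                         # (3, 110): 100 noise values + 10 label values per row
print(CHANNELS * IMG_SIZE * IMG_SIZE)  # 784 pixel values per generated image
```

The label part tells the generator which digit to draw, while the noise part differs from row to row, which is where the variety between two images of the same digit comes from.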

One-stop model training API

First, let’s create a new Model class that serves as a wrapper for different models and provides the one-stop training API. Create a new file named build_gan.py and import the necessary modules:

import os
import numpy as np
import torch
import torchvision.utils as vutils
from cgan import Generator as cganG
from cgan import Discriminator as cganD

Next, let's create the Model class. It initializes the Generator and Discriminator modules and provides train and eval methods, so that users can simply call Model.train() or Model.eval() elsewhere to run model training or evaluation.

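class Model(object):
    def __init__(self, name, device, data_loader, classes, channels, img_size, latent_dim):
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        # Build the generator and discriminator for the requested model type
        if self.name == 'cgan':
            self.netG = cganG(self.classes, self.channels, self.img_size, self.latent_dim)
            self.netD = cganD(self.classes, self.channels, self.img_size, self.latent_dim)
        else:
            # Fail early instead of hitting an AttributeError on self.netG below
            raise ValueError('Unsupported model name: {}'.format(self.name))
        # Move both networks to the target device (CPU or GPU)
        self.netG.to(self.device)
        self.netD.to(self.device)
        # Placeholders; the optimizers are assigned later
        self.optim_G = None
        self.optim_D = None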
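The constructor selects its networks with an if check on self.name. Since the Model class is meant to wrap several models, one possible alternative (a sketch, not the course's code) is a name-to-class lookup table. The stand-in classes StubGenerator and StubDiscriminator, the MODEL_REGISTRY dict, and the build_networks helper are all hypothetical; they replace cganG and cganD so the sketch runs without PyTorch or the cgan module.

```python
# Hypothetical stand-ins for cgan.Generator / cgan.Discriminator, so the
# sketch runs without the course's cgan module or PyTorch.
class StubGenerator:
    def __init__(self, classes, channels, img_size, latent_dim):
        self.args = (classes, channels, img_size, latent_dim)


class StubDiscriminator(StubGenerator):
    pass


# Assumed registry: model name -> (generator class, discriminator class).
MODEL_REGISTRY = {
    'cgan': (StubGenerator, StubDiscriminator),
}


def build_networks(name, classes, channels, img_size, latent_dim):
    """Look up the network pair registered under `name` and instantiate both."""
    try:
        gen_cls, disc_cls = MODEL_REGISTRY[name]
    except KeyError:
        raise ValueError('Unsupported model name: {}'.format(name)) from None
    args = (classes, channels, img_size, latent_dim)
    return gen_cls(*args), disc_cls(*args)


netG, netD = build_networks('cgan', classes=10, channels=1, img_size=28, latent_dim=100)
print(type(netG).__name__, type(netD).__name__)  # StubGenerator StubDiscriminator
```

With this layout, supporting another model means adding one registry entry rather than another pair of if branches.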

In the constructor, the generator network, netG, and the discriminator network, netD, are initialized based on the number of classes (classes), the number of image channels (channels), the image size (img_size), and the length of the latent vector (latent_dim). These arguments will be given later; for now, let's assume their values are already known. Since we need to ...