...


Training the Fashion MNIST cGAN


Learn how to train the conditional GAN and examine its results: the discriminator and generator losses during training, and the generator's output for the Fashion MNIST dataset.

The training loop

The main GAN training loop is updated to pass a label tensor to the discriminator and generator. The following shows only the code inside the epoch loop.

for label, image_data_tensor, label_tensor in mnist_dataset:
    # train discriminator on a real image with its true label
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    # random 1-hot label for the generator
    random_label = generate_random_one_hot(10)
    # train discriminator on a generated (false) image with that label
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100), random_label).detach(),
            random_label, torch.FloatTensor([0.0]))

    # different random 1-hot label for the generator
    random_label = generate_random_one_hot(10)
    # train generator to make D output 1.0 for its image under this label
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))
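The loop relies on a discriminator and generator whose forward() and train() methods accept a label tensor, plus two helpers, generate_random_seed and generate_random_one_hot, that are defined elsewhere in the course. The sketch below is one minimal way these pieces could fit together, assuming the label is concatenated onto each network's input. The layer sizes, LeakyReLU/LayerNorm layers, Adam learning rate and BCELoss are assumptions, not necessarily the lesson's exact choices, and the image and label fed in are random stand-ins for Fashion MNIST data. It shows why G.train takes D (the generator's loss comes from the discriminator's verdict on its output under the same label) and why detach() is used when training D on generated images.

```python
import torch
import torch.nn as nn


# Assumed helpers with the same names as those used in the loop
def generate_random_seed(size):
    # random normal seed vector for the generator
    return torch.randn(size)


def generate_random_one_hot(size):
    # 1-hot vector with one randomly chosen class set to 1.0
    label_tensor = torch.zeros(size)
    label_tensor[torch.randint(size, (1,)).item()] = 1.0
    return label_tensor


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 pixels + 10 label values -> 1 real/fake score (assumed sizes)
        self.model = nn.Sequential(
            nn.Linear(784 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid(),
        )
        self.loss_function = nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, image_tensor, label_tensor):
        # condition on the label by appending it to the flattened image
        return self.model(torch.cat((image_tensor, label_tensor)))

    def train(self, inputs, label_tensor, targets):
        # note: shadows nn.Module.train() to match the loop's D.train(...) calls
        outputs = self.forward(inputs, label_tensor)
        loss = self.loss_function(outputs, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 100 seed values + 10 label values -> 784 pixels in [0, 1] (assumed sizes)
        self.model = nn.Sequential(
            nn.Linear(100 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid(),
        )
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, seed_tensor, label_tensor):
        return self.model(torch.cat((seed_tensor, label_tensor)))

    def train(self, D, inputs, label_tensor, targets):
        # D judges the generated image under the same label
        g_output = self.forward(inputs, label_tensor)
        d_output = D.forward(g_output, label_tensor)
        loss = D.loss_function(d_output, targets)
        # only G's optimiser steps, so only G's weights are updated here
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()


# one pass of the loop body on stand-in data
D, G = Discriminator(), Generator()
image_data_tensor = torch.rand(784)          # stand-in for a real image
label_tensor = generate_random_one_hot(10)   # stand-in for its true label
D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
random_label = generate_random_one_hot(10)
D.train(G.forward(generate_random_seed(100), random_label).detach(),
        random_label, torch.FloatTensor([0.0]))
random_label = generate_random_one_hot(10)
G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))
```

Because train() shadows nn.Module.train(), calling D.eval() on these sketch classes would fail; the name is kept only to mirror the method calls in the loop.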

After training, we can plot generated images using the plot_image function of the generator class.
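The lesson's plot_image is a method of the generator class; its code is not shown here. As a hedged sketch of what such a function could do, the free functions below fix one class label as a 1-hot tensor, generate several images from different random seeds, and lay them out in a grid with matplotlib. The names make_one_hot, generate_image_grid, plot_images and StandInGenerator are hypothetical, the 2×3 grid and "Blues" colormap are assumptions, and StandInGenerator only exists so the sketch runs without a trained model.

```python
import torch
import torch.nn as nn


class StandInGenerator(nn.Module):
    # hypothetical stand-in with a forward(seed, label) signature like the lesson's G
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(100 + 10, 784), nn.Sigmoid())

    def forward(self, seed_tensor, label_tensor):
        return self.model(torch.cat((seed_tensor, label_tensor)))


def make_one_hot(label, num_classes=10):
    # 1-hot tensor with only the requested class set to 1.0
    label_tensor = torch.zeros(num_classes)
    label_tensor[label] = 1.0
    return label_tensor


def generate_image_grid(G, label, rows=2, cols=3, seed_size=100):
    # same label for every image, fresh random seed per image,
    # so the grid shows variety within a single class
    label_tensor = make_one_hot(label)
    with torch.no_grad():
        images = [G(torch.randn(seed_size), label_tensor).view(28, 28)
                  for _ in range(rows * cols)]
    return torch.stack(images).view(rows, cols, 28, 28)


def plot_images(G, label, rows=2, cols=3):
    import matplotlib.pyplot as plt  # imported lazily: only needed for drawing
    grid = generate_image_grid(G, label, rows, cols)
    # squeeze=False keeps axarr 2-D even when rows or cols is 1
    fig, axarr = plt.subplots(rows, cols, figsize=(16, 8), squeeze=False)
    for i in range(rows):
        for j in range(cols):
            axarr[i, j].imshow(grid[i, j].numpy(), interpolation="none", cmap="Blues")
            axarr[i, j].axis("off")
    return fig
```

With a trained generator G in place of the stand-in, plot_images(G, 9) would request six images of class 9, which is "Ankle boot" in Fashion MNIST. Keeping grid generation separate from drawing means the images can be inspected as tensors without matplotlib installed.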

The conditional GAN results

The code ...