Training the Fashion MNIST cGAN
Learn about training the conditional GAN on the Fashion MNIST dataset, the results it produces, the loss of the discriminator and the generator during training, and the generator's output.
The training loop
The main GAN training loop is updated to pass a label tensor to both the discriminator and the generator. The following code shows only the body of the epoch loop.
```python
import torch

# Body of the epoch loop. D, G, mnist_dataset, generate_random_seed and
# generate_random_one_hot are defined earlier (not shown here).
for label, image_data_tensor, label_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    # random 1-hot label for generator
    random_label = generate_random_one_hot(10)

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100), random_label).detach(),
            random_label, torch.FloatTensor([0.0]))

    # different random 1-hot label for generator
    random_label = generate_random_one_hot(10)

    # train generator
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))
```
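To see how the label tensors flow through one pass of this loop body, here is a minimal, self-contained sketch. Everything in it is an assumption rather than the course's own code: the helpers generate_random_seed and generate_random_one_hot are written from their names and the comments above, the tiny Discriminator and Generator classes condition on the class by concatenating the one-hot label to their inputs (one common approach, assumed here), and the training methods are named train_step instead of train so they don't shadow PyTorch's nn.Module.train. A random tensor stands in for a Fashion MNIST image.

```python
import torch
import torch.nn as nn


def generate_random_seed(size):
    # Hypothetical helper: a seed of normally distributed random values
    return torch.randn(size)


def generate_random_one_hot(size):
    # Hypothetical helper: all zeros except one randomly chosen class set to 1.0
    label_tensor = torch.zeros(size)
    label_tensor[torch.randint(size, (1,)).item()] = 1.0
    return label_tensor


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 pixels + 10 label values in, a single real/fake score out
        self.model = nn.Sequential(
            nn.Linear(784 + 10, 200), nn.LeakyReLU(0.02),
            nn.LayerNorm(200), nn.Linear(200, 1), nn.Sigmoid())
        self.loss_function = nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, image_tensor, label_tensor):
        # Condition on the class by appending the one-hot label to the image
        return self.model(torch.cat((image_tensor, label_tensor)))

    def train_step(self, inputs, label_tensor, targets):
        loss = self.loss_function(self.forward(inputs, label_tensor), targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # 100-value seed + 10 label values in, 784 pixel values out
        self.model = nn.Sequential(
            nn.Linear(100 + 10, 200), nn.LeakyReLU(0.02),
            nn.LayerNorm(200), nn.Linear(200, 784), nn.Sigmoid())
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, seed_tensor, label_tensor):
        return self.model(torch.cat((seed_tensor, label_tensor)))

    def train_step(self, D, seed_tensor, label_tensor, targets):
        # D judges the generated image under the same label; only G's
        # optimiser steps, so D's weights are left unchanged here
        d_output = D.forward(self.forward(seed_tensor, label_tensor), label_tensor)
        loss = D.loss_function(d_output, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()


D, G = Discriminator(), Generator()
image_data_tensor = torch.rand(784)          # stand-in for a real image
label_tensor = generate_random_one_hot(10)   # stand-in for its true label

# train discriminator on true
d_real = D.train_step(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

# train discriminator on false; detach() keeps gradients out of G
random_label = generate_random_one_hot(10)
fake = G.forward(generate_random_seed(100), random_label).detach()
d_fake = D.train_step(fake, random_label, torch.FloatTensor([0.0]))

# train generator with a different random label
random_label = generate_random_one_hot(10)
g_loss = G.train_step(D, generate_random_seed(100), random_label,
                      torch.FloatTensor([1.0]))
```

In the real loop these three steps run once per item of mnist_dataset, with the image and label tensors supplied by the dataset class. Because the fake image is detached before the discriminator's update, that update computes no gradients for the generator's weights.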
After training, we can plot the generated images using the plot_image function of the generator class.
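The idea behind plotting a conditional generator's output is to fix the one-hot label and vary only the random seed, so every image should belong to the chosen class. The sketch below shows that step without the plotting itself. The helper generate_images_for_label and the TinyGenerator stand-in are hypothetical and are not the course's plot_image implementation, which may differ.

```python
import torch
import torch.nn as nn


def generate_images_for_label(generator, label, count=6, seed_size=100, num_classes=10):
    # Hypothetical helper: fix the class label, vary only the random seed
    label_tensor = torch.zeros(num_classes)
    label_tensor[label] = 1.0
    images = []
    with torch.no_grad():  # inference only, no gradients needed
        for _ in range(count):
            output = generator(torch.randn(seed_size), label_tensor)
            images.append(output.view(28, 28))
    return torch.stack(images)


class TinyGenerator(nn.Module):
    # Stand-in conditional generator: 100-value seed + 10-value label -> 784 pixels
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(110, 784), nn.Sigmoid())

    def forward(self, seed_tensor, label_tensor):
        return self.model(torch.cat((seed_tensor, label_tensor)))


images = generate_images_for_label(TinyGenerator(), label=9)  # 9 = ankle boot
print(images.shape)  # torch.Size([6, 28, 28])
```

Each images[i] is a 28 x 28 tensor that can be passed to matplotlib's imshow to display it.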
The conditional GAN results
The code ...