Training the MNIST GAN
Learn about training the GAN.
Let’s train this GAN.
The training loop
The training loop is the same as before. The only changes are the data passed to the discriminator and the generator.
```python
import torch

# Discriminator, Generator, generate_random() and mnist_dataset
# are assumed to be defined earlier in the course

# create Discriminator and Generator
D = Discriminator()
G = Generator()

# train Discriminator and Generator
for label, image_data_tensor, target_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, torch.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random(1)).detach(), torch.FloatTensor([0.0]))

    # train generator
    G.train(D, generate_random(1), torch.FloatTensor([1.0]))
```
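The comment about detach() is easy to check in isolation. The following is a minimal sketch that uses hypothetical tiny networks (toy_generator, toy_discriminator) in place of the lesson's Generator and Discriminator. It only inspects the .grad attributes after a backward pass and skips the optimiser steps that the lesson's train() methods presumably perform. It shows that a detached generator output lets gradients reach the discriminator's weights but not the generator's, so no work is spent computing generator gradients during the discriminator's update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny networks standing in for the lesson's Generator and
# Discriminator; the layer sizes are illustrative only.
toy_generator = nn.Linear(1, 4)
toy_discriminator = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
loss_function = nn.MSELoss()

seed = torch.rand(1)                   # random input, in the spirit of generate_random(1)
fake_label = torch.FloatTensor([0.0])  # target for generated data

# 1) Discriminator step on a detached generator output
fake = toy_generator(seed).detach()    # cut the graph between G and D
loss = loss_function(toy_discriminator(fake), fake_label)
loss.backward()

d_got_grads = toy_discriminator[0].weight.grad is not None
g_grad_after_detach = toy_generator.weight.grad
print(d_got_grads)          # True: the discriminator's weights received gradients
print(g_grad_after_detach)  # None: nothing flowed back into the generator

# 2) The same step without detach() also back-propagates into the generator
loss = loss_function(toy_discriminator(toy_generator(seed)), fake_label)
loss.backward()
print(toy_generator.weight.grad is not None)  # True
```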
It takes a short while to complete the training. For us, it took just over four minutes.
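The discriminator's counter is printed every 10,000 training steps and reaches 120,000, not 60,000. This is because the discriminator is trained twice per loop iteration: once on each of the 60,000 MNIST images and once on each of 60,000 generated images.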
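The arithmetic can be checked with a small simulation. The following is a minimal sketch that assumes the counter is incremented once per call to D.train() and printed every 10,000 calls. The TrainingCounter class is hypothetical and is not the lesson's actual code.

```python
class TrainingCounter:
    """Hypothetical stand-in for the discriminator's internal counter:
    one tick per train() call, a printout every `print_every` ticks."""

    def __init__(self, print_every=10_000):
        self.counter = 0
        self.print_every = print_every
        self.printed = []  # counter values that triggered a printout

    def tick(self):
        self.counter += 1
        if self.counter % self.print_every == 0:
            print(f"counter = {self.counter}")
            self.printed.append(self.counter)


NUM_MNIST_IMAGES = 60_000  # size of the MNIST training set

d_counter = TrainingCounter()
for _ in range(NUM_MNIST_IMAGES):
    d_counter.tick()  # D.train() on a real MNIST image
    d_counter.tick()  # D.train() on a generated image

print(d_counter.counter)       # 120000
print(len(d_counter.printed))  # 12 printouts, one every 10,000 steps
```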
Discriminator loss during training
Let’s plot the loss values from training the discriminator.
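A loss chart like this can be drawn from a plain list of recorded loss values. The following is a minimal sketch that assumes the losses were appended to a list during training (called progress here, a hypothetical name) and uses matplotlib directly. The lesson's own plotting helper may differ. Small, faint markers are used because a full training run records many thousands of points.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_loss(progress):
    """Plot recorded loss values and return the figure.

    `progress` is assumed to be a list with one loss value per recording
    step (for example, one entry every few calls to D.train()).
    """
    fig, ax = plt.subplots(figsize=(16, 8))
    # faint, unconnected markers show where thousands of points cluster
    ax.plot(progress, marker=".", linestyle="none", alpha=0.1)
    ax.set_xlabel("recorded step")
    ax.set_ylabel("loss")
    ax.set_ylim(bottom=0)
    ax.grid(True)
    return fig
```

For example, `plot_loss(progress).savefig("discriminator_loss.png")` writes the chart to a file, where progress is whatever list of losses your training code recorded.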
That’s an interesting chart! The loss values fall towards zero and stay low for a while, suggesting the discriminator is ahead of the generator. Then the ...