...


Training the MNIST GAN


Learn about training the GAN.

Let’s train this GAN.

The training loop

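The training loop is the same as before. The only difference is the data we pass in: the discriminator now sees real MNIST images and the generator's outputs, and the generator is fed random inputs from generate_random(1).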

import torch

# Discriminator, Generator, generate_random() and mnist_dataset
# are the classes, helper and dataset defined in the earlier lessons

# create Discriminator and Generator
D = Discriminator()
G = Generator()

# train Discriminator and Generator
for label, image_data_tensor, target_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, torch.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random(1)).detach(), torch.FloatTensor([0.0]))

    # train generator
    G.train(D, generate_random(1), torch.FloatTensor([1.0]))
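The detach() call in the discriminator step is easy to gloss over. The following minimal sketch, assuming PyTorch, uses two tiny nn.Linear layers named g and d as stand-ins for the real networks (the names and sizes are illustrative, not the course's models) and checks which parameters receive gradients during a discriminator step with and without detach().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tiny stand-ins for the real networks (hypothetical sizes, not the MNIST
# models): g maps a random input to a fake sample, d scores a sample.
g = nn.Linear(4, 3)
d = nn.Linear(3, 1)

fake = g(torch.rand(4))       # generator output, still attached to g's graph
target = torch.tensor([0.0])  # the discriminator should label fakes as 0.0

# Discriminator step WITH detach(): backward() stops at the detached tensor.
loss = F.mse_loss(torch.sigmoid(d(fake.detach())), target)
loss.backward()
reached_g_with_detach = g.weight.grad is not None

# Same step WITHOUT detach(): gradients flow back through d and on into g.
loss = F.mse_loss(torch.sigmoid(d(fake)), target)
loss.backward()
reached_g_without_detach = g.weight.grad is not None

print("d received gradients:", d.weight.grad is not None)       # True
print("g reached with detach():", reached_g_with_detach)        # False
print("g reached without detach():", reached_g_without_detach)  # True
```

Skipping that unnecessary backward pass through the generator is the point of detaching its output in the discriminator step. The generator's own step keeps the output attached, because its gradients have to pass back through the discriminator to reach the generator.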

It takes a short while to complete the training. For us, it took just over four minutes.

The discriminator's training counter is printed every 10,000 steps and reaches 120,000, because the discriminator is trained on 60,000 MNIST images and 60,000 generated images: two training steps for every image in the dataset.
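To see where the 120,000 comes from, here is a pure-Python sketch of the counting alone. CountingDiscriminator and CountingGenerator are hypothetical stand-ins, not the course's classes: we assume the real Discriminator.train() increments a counter on every call, records the loss every 10 calls and prints the counter every 10,000 calls, and we skip the forward pass and weight update entirely.

```python
class CountingDiscriminator:
    """Hypothetical stand-in: mimics only the assumed bookkeeping of train()."""

    def __init__(self):
        self.counter = 0
        self.progress = []  # sampled loss values, e.g. for plotting later

    def train(self, loss_value):
        # the real method would also run the network and update its weights
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss_value)
        if self.counter % 10000 == 0:
            print("counter =", self.counter)


class CountingGenerator:
    """Hypothetical stand-in: counts generator training steps only."""

    def __init__(self):
        self.counter = 0

    def train(self):
        self.counter += 1


D = CountingDiscriminator()
G = CountingGenerator()
num_images = 60000  # size of the MNIST training set

for _ in range(num_images):
    D.train(0.5)  # step on a real image (placeholder loss value)
    D.train(0.5)  # step on a generated image (placeholder loss value)
    G.train()     # one generator step per image
```

The printout ends at counter = 120000, twice the number of images, while the generator is trained only 60,000 times.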

Discriminator loss during training

Let’s plot the loss values from training the discriminator.
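If you record the discriminator's loss values yourself, a chart like this can be drawn with matplotlib. The plot_progress() function below is a hypothetical helper, not part of the course code, and it assumes the losses were sampled once every 10 training steps. Small, semi-transparent markers keep the plot readable when thousands of points overlap. The values in the usage line are placeholders, not training results.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so no display is needed
import matplotlib.pyplot as plt


def plot_progress(progress, every=10):
    """Plot loss values that were sampled once every `every` training steps."""
    steps = [every * (i + 1) for i in range(len(progress))]
    fig, ax = plt.subplots(figsize=(16, 8))
    # dots only, mostly transparent, so dense regions show up as darker bands
    ax.plot(steps, progress, linestyle="none", marker=".", alpha=0.1)
    ax.set_xlabel("training step")
    ax.set_ylabel("loss")
    ax.set_ylim(bottom=0)  # losses here are non-negative
    ax.grid(True)
    return fig, ax


# placeholder values only, to show the call
fig, ax = plot_progress([1.0, 0.5, 0.25, 0.125, 0.5, 0.25])
```

Calling fig.savefig() on the returned figure writes the chart to an image file.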

That’s an interesting chart! The loss values fall towards zero and stay low for a while, suggesting the discriminator is ahead of the generator. Then the ...