...

Training The 1010 Pattern GAN

Learn how to train the GAN.

Training the 1010 pattern GAN

Finally, we’re ready to train the GAN using the three-step training loop. Have a look at the following code.

# create Discriminator and Generator
D = Discriminator()
G = Generator()

# train Discriminator and Generator
for i in range(10000):
    # train discriminator on true
    D.train(generate_real(), torch.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))

    # train generator
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))
    pass
  • We first create fresh discriminator and generator objects, D and G, before running the training loop 10,000 times.

  • Inside the loop, we can see the three steps of the GAN training loop we talked about earlier.

    • For step 1, we can see the discriminator being trained on the real data.

    • For step 2, we train the discriminator with a pattern from the generator. The detach() applied to the output of the generator detaches it from the computation graph. Normally, calling backward() on the discriminator loss causes error gradients to be calculated all the way back along the computation graph: from the discriminator loss, through the discriminator itself, and then back through the generator. Because we’re only training the discriminator in this step, we don’t need to calculate gradients for the generator, so detach() cuts the computation graph at that point. The following picture illustrates this.
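
We can sketch the effect of detach() with a tiny, self-contained example. The parameter values (1.5, 2.0, 3.0) and the two-tensor "generator/discriminator" stand-ins below are made up for illustration; they are not the GAN's actual networks, but the gradient behaviour is the same.

```python
import torch

# stand-ins: one "generator" parameter and one "discriminator" parameter
w_g = torch.tensor([2.0], requires_grad=True)
w_d = torch.tensor([3.0], requires_grad=True)

# without detach(): gradients flow back through the generator
g_out = w_g * 1.5                    # "generator" forward pass
loss = (w_d * g_out).sum()           # "discriminator" loss on generator output
loss.backward()
assert w_g.grad is not None          # the generator received a gradient

# reset and repeat, this time cutting the graph with detach()
w_g.grad = None
w_d.grad = None
g_out = w_g * 1.5
loss = (w_d * g_out.detach()).sum()  # detach() cuts the graph here
loss.backward()
assert w_g.grad is None              # no gradient reached the generator
assert w_d.grad is not None          # the discriminator still gets its gradient
```

This is exactly why step 2 of the loop wraps the generator output in detach(): the discriminator still learns, but no wasted gradient computation reaches the generator.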

You might ask why we do this. ...