Training The 1010 Pattern GAN
Learn how to train the GAN.
Finally, we’re ready to train the GAN using the three-step training loop. Have a look at the following code.
```python
# create Discriminator and Generator
D = Discriminator()
G = Generator()

# train Discriminator and Generator
for i in range(10000):

    # train discriminator on true
    D.train(generate_real(), torch.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))

    # train generator
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))

    pass
```
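As a reminder of what the discriminator is being shown, here is a sketch of the kind of `generate_real()` the loop assumes: a source of data close to the 1010 pattern, with a little random variation. The exact implementation appears earlier in this text; treat this version as illustrative.

```python
import random
import torch

# a sketch of generate_real(): returns values near the 1010 pattern,
# with a little random noise so the discriminator sees variety
def generate_real():
    real_data = torch.FloatTensor([
        random.uniform(0.8, 1.0),   # near 1
        random.uniform(0.0, 0.2),   # near 0
        random.uniform(0.8, 1.0),   # near 1
        random.uniform(0.0, 0.2),   # near 0
    ])
    return real_data
```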
- We first create fresh discriminator and generator objects, `D` and `G`, before running the training loop 10,000 times.
- Inside the loop, we can see the three steps of the GAN training loop we talked about earlier.
- For step 1, we can see the discriminator being trained on the real data.
- For step 2, we train the discriminator with a pattern from the generator. The `detach()` applied to the output of the generator detaches it from the computation graph. Normally, calling `backward()` on the discriminator loss causes error gradients to be calculated all the way back along the computation graph: from the discriminator loss, through the discriminator itself, and then back through the generator. Because we’re only training the discriminator, we don’t need to calculate the gradients for the generator. That `detach()` applied to the generator output cuts the computation graph at that point. The following picture illustrates this.
- You might ask why we do this. ...
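The graph-cutting effect of `detach()` can be seen in a tiny standalone sketch. The tensor names here are illustrative, not part of the GAN code: `x` simply stands in for anything upstream that has gradients enabled, the way the generator does.

```python
import torch

# a leaf tensor standing in for something upstream of the
# discriminator that has gradients enabled, like the generator
x = torch.tensor([0.5], requires_grad=True)

# without detach(), the result stays connected to the graph,
# so backward() sends gradients all the way back to x
y = (x * 2).sum()
y.backward()
print(x.grad)            # tensor([2.])

# with detach(), the graph is cut at that point: the detached
# copy carries no gradient history, so nothing flows back to x
z = (x * 2).detach()
print(z.requires_grad)   # False
```

This is exactly what happens in step 2 of the loop: the detached generator output behaves like plain data to the discriminator, so no gradients are computed for the generator.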