Improving GAN Training
Learn how we can improve GAN training and why these changes lead to better model performance.
Here we’ll try to fix the mode collapse and image clarity problems by improving the training quality of our GAN. We’ve already seen some ideas for doing this when we developed refinements to our MNIST classifier.
Changing the loss function to BCELoss
The first refinement is to use the binary cross entropy BCELoss() instead of the mean squared error MSELoss() as the loss function.
📝 We’ve already talked about how binary cross entropy loss makes more sense when our network is performing a classification task. It also punishes incorrect answers, and rewards correct ones, more strongly than MSELoss().
self.loss_function = nn.BCELoss()
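To see that stronger punishment concretely, here is a small illustrative comparison that isn’t part of the lesson’s own code: for a confident wrong prediction of 0.9 when the correct answer is 0.0, BCELoss() produces a much larger loss value than MSELoss().

import torch
import torch.nn as nn

# Illustrative only: compare the penalty for a confident wrong prediction.
prediction = torch.tensor([0.9])
target = torch.tensor([0.0])

mse = nn.MSELoss()(prediction, target)   # (0.9 - 0.0)^2 = 0.81
bce = nn.BCELoss()(prediction, target)   # -log(1 - 0.9) ≈ 2.30

print(mse.item(), bce.item())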
Changing the activation function to LeakyReLU
The next refinement we can make is to use LeakyReLU() activation functions in both the discriminator and the generator. We’ll apply them after the middle layer only, and keep Sigmoid() for the final layer because we want the outputs to be in the range 0 to 1.
📝 We’ve previously talked about how the LeakyReLU() activation reduces the problem of vanishing gradients for large signal values. It’s a popular way of improving the quality of neural network training in general.
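As a quick illustration, separate from the lesson’s own code, LeakyReLU(0.02) passes positive values through unchanged and scales negative values by 0.02, so the gradient for negative inputs is small but never exactly zero.

import torch
import torch.nn as nn

# Illustrative only: positive values pass through, negative values are scaled by 0.02.
leaky = nn.LeakyReLU(0.02)
x = torch.tensor([-2.0, 0.0, 3.0])
print(leaky(x))   # tensor([-0.0400, 0.0000, 3.0000])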
Normalizing the signals
Another refinement is to normalize the signals inside the neural network so that they are centered around a mean of zero, with their variance limited to avoid network saturation from large values.
📝 We’ve already seen how LayerNorm() has a positive effect on training.
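The following small demonstration, illustrative only and not part of the lesson’s code, shows what LayerNorm() does to a set of values: they come out centered around zero with roughly unit variance.

import torch
import torch.nn as nn

# Illustrative only: LayerNorm rescales each sample to roughly zero mean and unit variance.
norm = nn.LayerNorm(4)
x = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
y = norm(x)
print(y)                # approximately [[-1.34, -0.45, 0.45, 1.34]]
print(y.mean().item())  # approximately 0.0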
The following code shows this improved discriminator neural network.
self.model = nn.Sequential(
    nn.Linear(784, 200),
    nn.LeakyReLU(0.02),
    nn.LayerNorm(200),
    nn.Linear(200, 1),
    nn.Sigmoid()
)
The code for the generator has the same changes.
self.model = nn.Sequential(
    nn.Linear(1, 200),
    nn.LeakyReLU(0.02),
    nn.LayerNorm(200),
    nn.Linear(200, 784),
    nn.Sigmoid()
)
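As a quick sanity check, illustrative only, a standalone copy of this generator model takes a single random seed value and produces 784 output values, which can be reshaped into a 28×28 image.

import torch
import torch.nn as nn

# Illustrative only: a single random seed in, 784 values (a 28x28 image) out.
model = nn.Sequential(
    nn.Linear(1, 200),
    nn.LeakyReLU(0.02),
    nn.LayerNorm(200),
    nn.Linear(200, 784),
    nn.Sigmoid()
)

output = model(torch.rand(1))
print(output.shape)               # torch.Size([784])
print(output.view(28, 28).shape)  # torch.Size([28, 28])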
Changing the optimizer
Another refinement we tried earlier was the Adam optimiser. Let’s use it for both the discriminator and the generator.
self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
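To see how the four refinements fit together, here is a minimal sketch of a discriminator class. It assumes the same class structure we used in earlier lessons; the forward() and train() methods follow that earlier pattern rather than being new code for this lesson.

import torch
import torch.nn as nn

class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()

        # refined network: LeakyReLU activation and LayerNorm normalisation
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        # refined loss function and optimiser
        self.loss_function = nn.BCELoss()
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, inputs, targets):
        # one training step: forward pass, loss, backpropagation, weight update
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()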
Let’s see the effect of these four changes.
Sadly, we still have mode collapse. The images themselves are ...