...

Improving GAN Training

Learn how we can improve GAN training and why these changes lead to better model performance.

Here we’ll try to fix the mode collapse and image clarity problems by improving the quality of training in our GAN. We’ve already seen some ideas for doing this when we developed refinements to our MNIST classifier.

Changing the loss function to BCELoss

The first refinement is to use the binary cross entropy BCELoss() instead of the mean squared error MSELoss() as the loss function.

📝 We’ve already talked about how binary cross entropy loss makes more sense when our network is performing a classification task. It also punishes incorrect answers, and rewards correct ones, more strongly than the MSELoss().

self.loss_function = nn.BCELoss()
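
To make the difference concrete, here is a small standalone comparison, separate from the GAN classes themselves, of how the two loss functions score the same confidently wrong answer. The specific numbers are just an illustration.

import torch
import torch.nn as nn

# A real image should be classified as 1.0, but the discriminator answers 0.1.
target = torch.tensor([1.0])
wrong_prediction = torch.tensor([0.1])

mse = nn.MSELoss()(wrong_prediction, target)  # (1.0 - 0.1)^2 = 0.81
bce = nn.BCELoss()(wrong_prediction, target)  # -ln(0.1), about 2.30

print(mse.item(), bce.item())  # BCELoss penalises the mistake far more heavily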

Changing the activation function to LeakyReLU

The next refinement we can make is to use LeakyReLU() activation functions in both the discriminator and generator. We’ll only apply them after the middle layer, and keep the Sigmoid() for the final layer as we want the outputs to be in the range 0 to 1.

📝 We’ve previously talked about how the LeakyReLU() activation reduces the problem of vanishing gradients for large signal values. It’s a popular way of improving the quality of training neural networks in general.
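
As a quick standalone illustration, again separate from the GAN classes, the snippet below shows LeakyReLU(0.02) letting a small, non-zero signal through for negative inputs rather than flattening them to zero.

import torch
import torch.nn as nn

activation = nn.LeakyReLU(0.02)
x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])

# Negative inputs are scaled by 0.02 instead of being clipped to zero,
# so a gradient can still flow back through them during training.
print(activation(x))  # tensor([-0.0400, -0.0100, 0.0000, 0.5000, 2.0000])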

Normalizing the signals

Another refinement is to normalize the signals inside the neural network so they are centered around a mean of zero, with their variance limited to avoid network saturation caused by large values.

📝 We’ve already seen how LayerNorm() has a positive effect on training.
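
As a quick standalone check, the snippet below shows LayerNorm() taking values that are large and spread out and rescaling them so they are centered around zero with a limited spread.

import torch
import torch.nn as nn

norm = nn.LayerNorm(4)
x = torch.tensor([10.0, 20.0, 30.0, 40.0])

y = norm(x)
print(y)         # values rescaled to roughly the range -1.3 to 1.3
print(y.mean())  # approximately zero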

The following code describes this improved discriminator neural network.

self.model = nn.Sequential(
    nn.Linear(784, 200),
    nn.LeakyReLU(0.02),   # leaky ReLU activation after the middle layer
    nn.LayerNorm(200),    # normalize the middle layer signals
    nn.Linear(200, 1),
    nn.Sigmoid()          # keep the output in the range 0 to 1
)

The code for the generator has the same changes.

self.model = nn.Sequential(
    nn.Linear(1, 200),    # a single input feeds the middle layer of 200 nodes
    nn.LeakyReLU(0.02),   # leaky ReLU activation after the middle layer
    nn.LayerNorm(200),    # normalize the middle layer signals
    nn.Linear(200, 784),  # 784 outputs, one per pixel of a 28x28 image
    nn.Sigmoid()          # keep the pixel values in the range 0 to 1
)
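
As a quick sanity check, we can build the same stack on its own, outside the generator class, and confirm that a single random seed value produces 784 outputs, all squashed into the range 0 to 1 by the final Sigmoid().

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 200),
    nn.LeakyReLU(0.02),
    nn.LayerNorm(200),
    nn.Linear(200, 784),
    nn.Sigmoid()
)

output = model(torch.rand(1))
print(output.shape)                # torch.Size([784]), one flattened 28x28 image
print(output.min(), output.max())  # both between 0 and 1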

Changing the optimizer

Another refinement we tried earlier was the Adam optimizer. Let’s use it for both the discriminator and the generator.

self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
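
To see how the four refinements fit together, here is a minimal sketch of the improved discriminator class. The surrounding structure, an nn.Module subclass with a forward() method, is an assumption based on the discriminator built earlier in the course; the training loop and any progress-tracking code are not shown here.

import torch
import torch.nn as nn

class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()

        # refinements: LeakyReLU() activations and LayerNorm() normalization
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 1),
            nn.Sigmoid()
        )

        # refinement: binary cross entropy loss
        self.loss_function = nn.BCELoss()

        # refinement: Adam optimiser
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

    def forward(self, inputs):
        return self.model(inputs)

The generator class gets the same treatment: the same four refinements wrapped around its own nn.Sequential stack shown above.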

Let’s see the effect of these four changes.

Sadly, we still have mode collapse. The images themselves are ...