Implementing Vanilla GAN using Keras

Generative Adversarial Networks (GANs) have become an increasingly popular topic in AI because they can generate realistic data across various domains, from images to music and beyond. Among the different GAN variants, the vanilla GAN is the fundamental architecture on which many other GANs are built. Here, we will explore how the vanilla GAN works and implement it from scratch.

Vanilla GAN architecture

At its core, a vanilla GAN consists of two neural networks: a generator $G$ and a discriminator $D$. The figure below shows a simplified workflow of the vanilla GAN.

Typical architecture of the vanilla GAN

The generator aims to generate synthetic data samples that resemble the real data $x \sim p_{data}(x)$, while the discriminator tries to differentiate between the real samples $x$ and the fake samples $G(z)$. These two networks engage in a min-max game, where the generator aims to fool the discriminator by producing realistic samples, and the discriminator aims to differentiate between real and fake samples accurately.
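
For reference, this min-max game is usually written as the following value function. The symbol $p_z(z)$ for the noise distribution does not appear in the text above and is introduced here as an assumption about notation:

```latex
\min_G \max_D V(D, G) =
    \mathbb{E}_{x \sim p_{data}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator pushes $D(x)$ toward 1 and $D(G(z))$ toward 0, which raises $V$, while the generator pushes $D(G(z))$ toward 1, which lowers it.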

Implementation of vanilla GAN

Here is a step-by-step implementation of vanilla GAN:

  1. Import the necessary libraries for creating and visualizing the GAN:

import os
import numpy as np
import matplotlib.pyplot as plt
# Jupyter magic that renders plots inline (omit this line when running as a plain Python script)
%matplotlib inline
import keras
from keras.models import Model, Sequential
from keras.layers import Dense, BatchNormalization, Reshape, Dropout, LeakyReLU, Input, Flatten
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import plot_model
  2. Define the required parameters for the model:

# Number of training iterations (each one processes a single batch, not the full dataset)
epochs = 20000
# Shape of the MNIST images (28x28 pixels with 1 channel)
mnist_shape = (28, 28, 1)
# Batch size for training
batch_size = 128
# Shape of the noise input to the generator (100-dimensional vector)
noise_shape = (100,)
# Interval (in iterations) for logging the losses during training
save_every = 1000
  3. Define and build the generator network:

def build_generator(noise_shape, mnist_shape):
    # Define input layer for the generator model
    noise = Input(shape=noise_shape)
    # Fully connected layer with 256 units
    x = Dense(256)(noise)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Fully connected layer with 512 units
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Fully connected layer with 1024 units
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Output layer with one unit per pixel of the output image
    # Use tanh activation to ensure pixel values are in the range [-1, 1]
    x = Dense(np.prod(mnist_shape), activation='tanh')(x)
    # Reshape output to match the desired image shape
    x = Reshape(mnist_shape)(x)
    # Create and return the generator model (noise in, image out)
    return Model(noise, x)

# Call build_generator to create the generator model
generator = build_generator(noise_shape, mnist_shape)
generator.summary()
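
To make the generator's last two layers concrete, here is a minimal NumPy sketch of what they do to the shapes and value range. It does not use Keras: the random `pre_activation` array is a hypothetical stand-in for the output of the final `Dense` layer before its activation, and the batch size of 4 is arbitrary.

```python
import numpy as np

mnist_shape = (28, 28, 1)
# Number of units in the output layer: 28 * 28 * 1 = 784
n_pixels = int(np.prod(mnist_shape))

# Hypothetical stand-in for the last Dense layer's pre-activation output (batch of 4)
rng = np.random.default_rng(0)
pre_activation = rng.normal(size=(4, n_pixels))

# tanh squashes every value into the range [-1, 1]
flat_images = np.tanh(pre_activation)

# Reshape each flat vector into an image, as the Reshape layer does
images = flat_images.reshape((-1,) + mnist_shape)
print(n_pixels, images.shape)
```

The `[-1, 1]` output range is why the training images are later scaled to the same range.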
  4. Define and build the discriminator network:

def build_discriminator(mnist_shape):
    # Define input layer for the discriminator model
    input_img = Input(shape=mnist_shape)
    # Flatten the input image
    x = Flatten()(input_img)
    # First fully connected layer with 512 units
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Second fully connected layer with 256 units
    x = Dense(256)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Output layer with 1 unit and sigmoid activation for binary classification
    x = Dense(1, activation='sigmoid')(x)
    # Create and return the discriminator model (image in, probability of "real" out)
    return Model(input_img, x)

# Call build_discriminator to create the discriminator model
discriminator = build_discriminator(mnist_shape)
discriminator.summary()
  5. Compile the generator and discriminator networks:

# Short aliases used in the rest of the code
G = generator
D = discriminator
# Compile the generator (G) model
# Use the Adam optimizer with learning rate 0.0002 and beta_1 = 0.5
# Use binary crossentropy as the loss function
G.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')
# Compile the discriminator (D) model
# Use the Adam optimizer with learning rate 0.0002 and beta_1 = 0.5
# Use binary crossentropy as the loss function and accuracy as the metric to monitor
D.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])
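
Both models are compiled with binary crossentropy, so it may help to see what that loss computes. The following is a minimal NumPy sketch, not the Keras implementation: the helper `binary_crossentropy` and the discriminator outputs `d_on_real` and `d_on_fake` are hypothetical values chosen for illustration, and it assumes the usual convention of label 1 for real and label 0 for fake.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary crossentropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_sample = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return float(np.mean(per_sample))

# Hypothetical discriminator outputs (probability that the input is real)
d_on_real = np.array([0.9, 0.8])  # D(x) for two real images
d_on_fake = np.array([0.2, 0.1])  # D(G(z)) for two generated images

# Discriminator: real images are labeled 1, fake images are labeled 0
d_loss_real = binary_crossentropy(np.ones(2), d_on_real)   # -mean(log D(x))
d_loss_fake = binary_crossentropy(np.zeros(2), d_on_fake)  # -mean(log(1 - D(G(z))))

# Generator: fake images are labeled 1, so the loss is -mean(log D(G(z)))
g_loss = binary_crossentropy(np.ones(2), d_on_fake)

print(d_loss_real, d_loss_fake, g_loss)
```

With these example values the discriminator is doing well, so its two losses are small while the generator's loss is large.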
  6. Combine the generator and discriminator and build the GAN:

# Define input layer for noise (named gan_input to avoid shadowing the built-in input)
gan_input = Input(shape=noise_shape)
# Generate image from noise using generator G
image = G(gan_input)
# Freeze the discriminator's weights so that only the generator is updated
# when the combined model is trained
D.trainable = False
# Perform classification on generated image using discriminator D
validity = D(image)
# Combine generator G and discriminator D into a single model
# This model takes noise as input and outputs the classification result of the generated image
D_G_model = Model(gan_input, validity)
# Compile the combined model
D_G_model.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')
D_G_model.summary()
  7. Load the dataset, preprocess it, and train the model:

# Load the MNIST dataset and keep only the training images (X_train)
(X_train, _), (_, _) = mnist.load_data()
print(X_train.shape)
# Scale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output
X_train = (X_train.astype('float32') - 127.5) / 127.5
# Add a channel dimension so that each image has shape (28, 28, 1)
X_train = np.expand_dims(X_train, axis=3)
# Print the mean and standard deviation as a quick check on the scaling
print(np.mean(X_train), np.std(X_train))
# The discriminator is trained on half a batch of real and half a batch of fake images
half_batch = batch_size // 2
# Train the GAN
for epoch in range(epochs):
    # Train the discriminator
    # Sample real images from the training data
    indices = np.random.randint(0, X_train.shape[0], half_batch)
    images = X_train[indices]
    # Train the discriminator on real images (label 1)
    d_real_loss = D.train_on_batch(images, np.ones((half_batch, 1)))
    # Generate fake images using random noise as input to the generator
    noise = np.random.uniform(0, 1, (half_batch, noise_shape[0]))
    noise_images = G.predict(noise)
    # Train the discriminator on fake images (label 0)
    d_fake_loss = D.train_on_batch(noise_images, np.zeros((half_batch, 1)))
    # Compute the average discriminator loss and accuracy
    d_loss = np.add(d_real_loss, d_fake_loss) / 2
    # Train the generator
    # Generate noise for the full batch
    noise = np.random.uniform(0, 1, (batch_size, noise_shape[0]))
    # Train the combined model on noise, labeling the fakes as real (label 1)
    g_loss = D_G_model.train_on_batch(noise, np.ones((batch_size, 1)))
    # Log the losses at regular intervals
    if epoch % save_every == 0:
        print('Epoch: {}, D_Loss:{}, D_Acc:{}, G_Loss:{}'.format(epoch, d_loss[0], d_loss[1], g_loss))
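
The scaling applied to `X_train` before training can be checked in isolation. This minimal NumPy sketch uses a few hand-picked pixel values (an assumption for illustration, not MNIST data) to show that the transformation maps `[0, 255]` onto `[-1, 1]` and that it can be inverted to recover displayable pixel values.

```python
import numpy as np

# Hand-picked 8-bit pixel values standing in for image data
pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

# Same scaling as applied to X_train: [0, 255] -> [-1, 1]
scaled = (pixels.astype('float32') - 127.5) / 127.5

# Inverse transformation: [-1, 1] -> [0, 255], rounded back to integers
restored = np.rint(scaled * 127.5 + 127.5).astype(np.uint8)

print(scaled.min(), scaled.max())
print(restored)
```

The same inverse can be applied to the generator's output if the images need to be saved as 8-bit files.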
  8. Test the generator by passing random noise and checking what it generates:

noise = np.random.uniform(0, 1, (1, noise_shape[0]))
image = G.predict(noise)
# Visualise
plt.imshow(image[0,:,:, 0], cmap='gray')
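
Looking at a single image says little about the variety of the generator's output, so it can be useful to view many samples at once. The sketch below is one possible way to do that with NumPy alone: `tile_images` is a hypothetical helper, and the random `fake` array is a stand-in for the result of `G.predict` on 16 noise vectors.

```python
import numpy as np

def tile_images(images, rows, cols):
    """Arrange rows*cols images of shape (H, W, 1) into one (rows*H, cols*W) array."""
    n, h, w, _ = images.shape
    if n != rows * cols:
        raise ValueError('expected rows * cols images')
    grid = images[..., 0].reshape(rows, cols, h, w)
    # Reorder axes to (rows, h, cols, w) so that reshaping lays images side by side
    grid = grid.transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)

# Stand-in for generated images: random values in the tanh range [-1, 1]
rng = np.random.default_rng(0)
fake = rng.uniform(-1, 1, size=(16, 28, 28, 1))

# Build a 4x4 grid and rescale from [-1, 1] to [0, 1] for display
grid = (tile_images(fake, 4, 4) + 1) / 2
print(grid.shape)
```

The resulting array can be passed to `plt.imshow(grid, cmap='gray')` in the same way as the single image above.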

Conclusion

Implementing a vanilla GAN from scratch builds a working understanding of how a generator and a discriminator are trained against each other. From here, experimenting with different architectures, loss functions, and training strategies is a natural next step toward building more capable generative models.

Copyright ©2025 Educative, Inc. All rights reserved