Generative Adversarial Networks (GANs) have become an increasingly popular topic in AI thanks to their ability to generate high-quality data across various domains, from images to music and beyond. Among the many GAN variants, the vanilla GAN stands out as the fundamental architecture on which most others are built. Here, we will explore how a vanilla GAN works and implement it from scratch.
At its core, a vanilla GAN consists of two neural networks: a generator and a discriminator.
The generator aims to produce synthetic data samples that resemble the real data, while the discriminator learns to distinguish these generated samples from genuine ones. The two networks are trained against each other, so progress in one forces the other to improve.
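Formally, this adversarial setup can be summarized by the minimax objective introduced in the original GAN paper, where the discriminator D tries to maximize and the generator G tries to minimize the value function:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$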
Here is a step-by-step implementation of vanilla GAN:
Import the necessary libraries for creating and visualizing the GAN:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.models import Model, Sequential
from keras.layers import Dense, BatchNormalization, Reshape, Dropout, LeakyReLU, Input, Flatten
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import plot_model
Define the required parameters for the model:
# Number of epochs for training
epochs = 20000
# Shape of the MNIST images (28x28 pixels with 1 channel)
mnist_shape = (28, 28, 1)
# Batch size for training
batch_size = 128
# Shape of the noise input to the generator (100-dimensional vector)
noise_shape = (100,)
# Interval for saving generated images and models during training
save_every = 1000
Define and build the generator network:
def build_generator(noise_shape, mnist_shape):
    # Define input layer for the generator model
    noise = Input(shape=noise_shape)
    # Fully connected layer with 256 units
    x = Dense(256)(noise)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Fully connected layer with 512 units
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Fully connected layer with 1024 units
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    # Output layer with units equal to the number of pixels in the output image
    # Use tanh activation to ensure pixel values are in the range [-1, 1]
    x = Dense(np.prod(mnist_shape), activation='tanh')(x)
    # Reshape output to match the desired image shape
    x = Reshape(mnist_shape)(x)
    # Create and return the generator model
    return Model(noise, x)

# Call build_generator to create the generator model (G)
G = build_generator(noise_shape, mnist_shape)
G.summary()
Define and build the discriminator network:
def build_discriminator(mnist_shape):
    # Define input layer for the discriminator model
    input_img = Input(shape=mnist_shape)
    # Flatten the input image
    x = Flatten()(input_img)
    # First fully connected layer with 512 units
    x = Dense(512)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Second fully connected layer with 256 units
    x = Dense(256)(x)
    x = LeakyReLU(alpha=0.2)(x)
    # Output layer with 1 unit and sigmoid activation for binary classification
    x = Dense(1, activation='sigmoid')(x)
    # Create and return the discriminator model
    return Model(input_img, x)

# Call build_discriminator to create the discriminator model (D)
D = build_discriminator(mnist_shape)
D.summary()
Compile the generator and discriminator networks:
# Compile the generator (G) model
# Use the Adam optimizer with learning rate 0.0002 and beta_1 of 0.5
# Use binary crossentropy as the loss function
G.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')

# Compile the discriminator (D) model
# Use the Adam optimizer with learning rate 0.0002 and beta_1 of 0.5
# Use binary crossentropy as the loss function and accuracy as the metric to monitor
D.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])
Combine the generator and discriminator and build the GAN:
# Define input layer for noise
noise_input = Input(shape=noise_shape)
# Generate an image from noise using generator G
generated_image = G(noise_input)
# Freeze the discriminator's weights while training the combined model
D.trainable = False
# Classify the generated image using discriminator D
validity = D(generated_image)
# Combine generator G and discriminator D into a single model
# This model takes noise as input and outputs the discriminator's prediction for the generated image
D_G_model = Model(noise_input, validity)
# Compile the combined model
D_G_model.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')
D_G_model.summary()
Load the dataset, preprocess it, and train the model:
# Load the MNIST dataset and extract the training images (X_train)
(X_train, _), (_, _) = mnist.load_data()
print(X_train.shape)

# Center and normalize the pixel values of the images
# Scale pixel values to the range [-1, 1] to match the generator's tanh output
X_train = (X_train.astype('float32') - 127.5) / 127.5

# Expand the dimensions of the training data to include a channel dimension
X_train = np.expand_dims(X_train, axis=3)

# The mean and standard deviation of the training data
# Useful for verifying that the data is properly centered and normalized
print(np.mean(X_train), np.std(X_train))

# The discriminator is trained on half a batch of real and half a batch of fake images
half_batch = batch_size // 2

# Train the GAN
for epoch in range(epochs):
    # ---- Train the discriminator ----
    # Sample real images from the training data
    indices = np.random.randint(0, X_train.shape[0], half_batch)
    images = X_train[indices]
    # Train the discriminator on real images (label 1)
    d_real_loss = D.train_on_batch(images, np.ones((half_batch, 1)))

    # Generate fake images using random noise as input to the generator
    noise = np.random.uniform(0, 1, (half_batch, noise_shape[0]))
    noise_images = G.predict(noise)
    # Train the discriminator on fake images (label 0)
    d_fake_loss = D.train_on_batch(noise_images, np.zeros((half_batch, 1)))

    # Compute the average discriminator loss
    d_loss = np.add(d_real_loss, d_fake_loss) / 2

    # ---- Train the generator ----
    # Generate noise for the full batch
    noise = np.random.uniform(0, 1, (batch_size, noise_shape[0]))
    # Train the combined model with "real" labels so the generator learns to fool the discriminator
    g_loss = D_G_model.train_on_batch(noise, np.ones((batch_size, 1)))

    # Periodically report the training progress
    if epoch % save_every == 0:
        print('Epoch: {}, D_Loss: {}, D_Acc: {}, G_Loss: {}'.format(epoch, d_loss[0], d_loss[1], g_loss))
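The save_every parameter is described as an interval for saving generated images and models, but the loop above only prints metrics. Below is a minimal sketch of a helper that could be called inside the if epoch % save_every == 0: block; the sample_images name and the gan_images output directory are our own assumptions, not part of the original code.

# A minimal sketch (assumed helper) for saving a grid of generated digits at each save interval
def sample_images(epoch, rows=5, cols=5, out_dir='gan_images'):
    # Create the output directory if it does not exist
    os.makedirs(out_dir, exist_ok=True)
    # Sample noise and generate a batch of images with the trained generator
    noise = np.random.uniform(0, 1, (rows * cols, noise_shape[0]))
    generated = G.predict(noise)
    # Rescale images from [-1, 1] to [0, 1] for display
    generated = 0.5 * generated + 0.5
    # Plot the generated digits in a rows x cols grid and save to disk
    fig, axes = plt.subplots(rows, cols, figsize=(5, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated[i, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(os.path.join(out_dir, 'mnist_{}.png'.format(epoch)))
    plt.close(fig)

Calling sample_images(epoch) alongside the print statement would write a grid of generated digits to disk every save_every epochs, which makes it easy to track how the samples improve during training.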
Test the generator by passing random noise and checking what it generates:
# Generate a single image from random noise
noise = np.random.uniform(0, 1, (1, noise_shape[0]))
image = G.predict(noise)
# Visualize the generated image
plt.imshow(image[0, :, :, 0], cmap='gray')
We can view the results of the model by running the following live app:
The code won’t run to completion here because training the vanilla GAN for 20,000 epochs is impractical without a GPU. However, the output can be observed in prerun mode in the Jupyter Notebook below.
Implementing a vanilla GAN from scratch provides a deep understanding of the underlying concepts of generative modeling. By following the steps outlined above, we can begin our journey into the fascinating world of GANs. Experimenting with different architectures, loss functions, and training strategies will further sharpen our ability to build compelling generative models.