...


Training the 01110 Pattern GAN


Learn how we can now train the 01110 pattern GAN.

The code is similar to the code we used earlier to train the simple 1010 pattern GAN.
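One way to see how little changes between the two GANs is to build a real sample directly from its bit pattern. The sketch below is our own generalization, not the lesson's code. The hypothetical helper `make_real_sample` maps each 1 to a value drawn from `random.uniform(0.8, 1.0)` and each 0 to a value drawn from `random.uniform(0.0, 0.2)`, which are the same ranges the `generate_real()` function in main.py uses. It returns a plain Python list so it runs without PyTorch. Wrapping the result in `torch.FloatTensor(...)` would give the tensor form the training loop expects.

```python
import random

def make_real_sample(pattern):
    """Hypothetical helper: turn a bit string such as "01110" into one noisy
    real sample. Each '1' becomes a value in [0.8, 1.0] and each '0' a value
    in [0.0, 0.2], matching the ranges used by generate_real()."""
    sample = []
    for bit in pattern:
        if bit == "1":
            sample.append(random.uniform(0.8, 1.0))
        elif bit == "0":
            sample.append(random.uniform(0.0, 0.2))
        else:
            raise ValueError(f"pattern may only contain '0' and '1', got {bit!r}")
    return sample

# The earlier GAN learned a 4-value 1010 pattern; this one learns a 5-value 01110 pattern.
old_sample = make_real_sample("1010")
new_sample = make_real_sample("01110")
print(len(old_sample), len(new_sample))  # 4 5
print([round(v, 2) for v in new_sample])
```

Only the pattern, and therefore the length of each sample, changes. That is why the discriminator's input size and the generator's output size are the main things that have to grow from 4 to 5.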

Training the 01110 pattern GAN

We’re ready to train the GAN using the three-step training loop. Have a look at the following code:

main.py (the Discriminator and Generator classes are imported from Discriminator.py and Generator.py):
```python
# import libraries
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
import random
import numpy

from Discriminator import Discriminator
from Generator import Generator

def generate_real():
    real_data = torch.FloatTensor(
        [random.uniform(0.0, 0.2),
         random.uniform(0.8, 1.0),
         random.uniform(0.8, 1.0),
         random.uniform(0.8, 1.0),
         random.uniform(0.0, 0.2)])
    return real_data

def generate_random(size):
    random_data = torch.rand(size)
    return random_data

# create Discriminator and Generator
D = Discriminator()
G = Generator()

# define an image list to store how output evolves during training
image_list = []

# train Discriminator and Generator
for i in range(10000):

    # train discriminator on true
    D.train(generate_real(), torch.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))

    # train generator
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))

    # add image to list every 1000 iterations
    if (i % 1000 == 0):
        image_list.append( G.forward(torch.FloatTensor([0.5])).detach().numpy() )

    pass

# manually run generator to see its outputs
print("Generator's output after training:", G.forward(torch.FloatTensor([0.5])))

# plot images collected during training
plt.figure(figsize = (16,8))
plt.imshow(numpy.array(image_list).T, interpolation='none', cmap='Blues')
plt.savefig('output/legend.png')
```
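To see why the plot transposes `image_list`, it helps to count the snapshots. The condition `i % 1000 == 0` is true for i = 0, 1000, ..., 9000, so the loop stores 10 generator outputs of 5 values each. `numpy.array(image_list)` therefore has shape (10, 5), and `.T` turns it into (5, 10): one row per position in the pattern and one column per snapshot, so training progress reads from left to right in the image. The sketch below reproduces only this bookkeeping. The function `fake_generator_output` is a hypothetical stand-in for `G.forward(torch.FloatTensor([0.5])).detach().numpy()` that returns random numbers, so the example runs without the Generator class.

```python
import numpy

# Hypothetical stand-in for the generator: 5 random values in [0, 1).
rng = numpy.random.default_rng(0)

def fake_generator_output():
    return rng.random(5)

image_list = []
snapshot_steps = []
for i in range(10000):
    # (the three training steps are omitted here)
    if i % 1000 == 0:
        image_list.append(fake_generator_output())
        snapshot_steps.append(i)

images = numpy.array(image_list)
print(len(snapshot_steps), snapshot_steps[0], snapshot_steps[-1])  # 10 0 9000
print(images.shape)    # (10, 5): one row per snapshot
print(images.T.shape)  # (5, 10): one row per pattern position
```

With the real generator and `cmap='Blues'`, higher values are drawn in darker blue. If training succeeds, the later columns should show pale top and bottom rows with three darker rows between them, which is the 01110 pattern.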

Discriminator and generator objects

We first create fresh discriminator and generator objects and then run the training loop 10,000 times.
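The contents of Discriminator.py and Generator.py are not shown in this section. As a rough guide to what the loop is calling, here is a minimal sketch that assumes a design in the spirit of the earlier 1010 GAN, widened from 4 to 5 values. The layer sizes, the `nn.MSELoss` loss, the SGD learning rate of 0.01, and the bodies of both `train()` methods are assumptions, not the lesson's files. The methods are named `train` only so that main.py's `D.train(...)` and `G.train(...)` calls work. Note that this shadows PyTorch's own `nn.Module.train(mode)`.

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Hypothetical sketch: scores a 5-value pattern as real (1) or fake (0)."""
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(5, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )
        self.loss_function = nn.MSELoss()
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, inputs):
        return self.model(inputs)

    # Shadows nn.Module.train(mode) so D.train(data, target) works;
    # do not call D.eval() on this class, since eval() calls train(False).
    def train(self, inputs, targets):
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()

class Generator(nn.Module):
    """Hypothetical sketch: maps a single seed value to a 5-value pattern."""
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 3),
            nn.Sigmoid(),
            nn.Linear(3, 5),
            nn.Sigmoid(),
        )
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, D, inputs, targets):
        # Score the generator's output with the discriminator against the
        # "real" target. Gradients also reach D's parameters here, but only
        # G's optimiser steps; D.train() zeroes D's gradients before its own step.
        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        return loss.item()
```

Saved as Discriminator.py and Generator.py, these two classes provide the `forward()` and `train()` signatures that main.py's loop uses.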

Training loop

Let’s examine the main.py file. Inside ...