Training the 01110 Pattern GAN
Learn how to train the 01110 pattern GAN.
The code is similar to the code we used earlier to train the GAN on the simple 1010 pattern.
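The main change from the 1010 GAN is the real-data source. As a sketch, the same idea can be written for any 0/1 pattern: each 1 becomes a value drawn from [0.8, 1.0] and each 0 a value drawn from [0.0, 0.2], the same ranges the course's `generate_real()` uses. The helper name `generate_real_pattern` and its `pattern` argument are hypothetical and are not part of the course code.

```python
import random

import torch


def generate_real_pattern(pattern):
    # Hypothetical helper: map each 1 to a value in [0.8, 1.0]
    # and each 0 to a value in [0.0, 0.2], as generate_real() does.
    values = []
    for bit in pattern:
        if bit == 1:
            values.append(random.uniform(0.8, 1.0))
        else:
            values.append(random.uniform(0.0, 0.2))
    return torch.FloatTensor(values)


# The 1010 pattern from the earlier lesson and the 01110 pattern used here
sample_1010 = generate_real_pattern([1, 0, 1, 0])
sample_01110 = generate_real_pattern([0, 1, 1, 1, 0])
print(sample_01110)
```

Only the length of the output vector and the position of the high values change between the two patterns, which is why the rest of the training code carries over almost unchanged.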
Training the 01110 pattern GAN
We’re ready to train the GAN using the three-step training loop. Have a look at the following code:
main.py
# import libraries
import os
import random

import matplotlib.pyplot as plt
import numpy
import pandas
import torch
import torch.nn as nn

from Discriminator import Discriminator
from Generator import Generator


def generate_real():
    # a noisy 01110 pattern: low values at both ends, high values in the middle
    real_data = torch.FloatTensor([
        random.uniform(0.0, 0.2),
        random.uniform(0.8, 1.0),
        random.uniform(0.8, 1.0),
        random.uniform(0.8, 1.0),
        random.uniform(0.0, 0.2),
    ])
    return real_data


def generate_random(size):
    # uniform random values in [0, 1)
    random_data = torch.rand(size)
    return random_data


# create Discriminator and Generator
D = Discriminator()
G = Generator()

# define an image list to store how the output evolves during training
image_list = []

# train Discriminator and Generator
for i in range(10000):
    # 1. train discriminator on real data
    D.train(generate_real(), torch.FloatTensor([1.0]))

    # 2. train discriminator on generated (fake) data
    # use detach() so gradients in G are not calculated
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))

    # 3. train generator so that D scores its output as real (1.0)
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))

    # add a snapshot of the generator's output every 1000 iterations
    if i % 1000 == 0:
        image_list.append(G.forward(torch.FloatTensor([0.5])).detach().numpy())

# manually run the generator to see its output
print("Generator's output after training:", G.forward(torch.FloatTensor([0.5])))

# plot the snapshots collected during training (one column per snapshot)
plt.figure(figsize=(16, 8))
plt.imshow(numpy.array(image_list).T, interpolation='none', cmap='Blues')
os.makedirs('output', exist_ok=True)  # make sure the output folder exists
plt.savefig('output/legend.png')
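The Generator.py and Discriminator.py files are not reproduced in this section. To make the calls in main.py concrete (`D.train(inputs, targets)`, `D.forward(...)`, `G.forward(...)` and `G.train(D, inputs, targets)`), here is a minimal sketch of classes that would satisfy them. The layer sizes (5→3→1 for the discriminator, 1→3→5 for the generator), the sigmoid activations, `nn.MSELoss`, and SGD with a learning rate of 0.01 are all assumptions, not necessarily what the course files contain.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed architecture: 5 inputs -> 3 hidden -> 1 output score
        self.model = nn.Sequential(
            nn.Linear(5, 3),
            nn.Sigmoid(),
            nn.Linear(3, 1),
            nn.Sigmoid(),
        )
        # Assumed loss and optimiser
        self.loss_function = nn.MSELoss()
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, inputs, targets):
        # one gradient step pushing D's output towards the target label
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed architecture: 1 seed value -> 3 hidden -> 5 pattern values
        self.model = nn.Sequential(
            nn.Linear(1, 3),
            nn.Sigmoid(),
            nn.Linear(3, 5),
            nn.Sigmoid(),
        )
        self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)

    def forward(self, inputs):
        return self.model(inputs)

    def train(self, D, inputs, targets):
        # the loss is measured on D's verdict, but only G's optimiser steps
        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
```

Two consequences of this design are worth knowing. Defining a method named `train` overrides `nn.Module.train()`, which `eval()` calls internally, so these objects should not be switched to evaluation mode with `eval()`. And `loss.backward()` in `Generator.train` also fills in gradients for D's parameters; they are never applied by G, and `zero_grad()` clears them at the start of the next `Discriminator.train` call.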
Discriminator and generator objects
We first create fresh discriminator and generator objects and then run the training loop 10,000 times.
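Because the loop appends a generator snapshot whenever `i % 1000 == 0`, 10,000 iterations yield 10 snapshots, taken at i = 0, 1000, ..., 9000 (so the last snapshot is not taken after the final iteration). The sketch below uses random 5-element stand-in vectors instead of real generator outputs, an assumption made to keep it self-contained, to show how `numpy.array(image_list).T` lays the snapshots out for `imshow`: each column is one snapshot and each row is one of the five pattern positions.

```python
import numpy

image_list = []
for i in range(10000):
    if i % 1000 == 0:
        # Stand-in for G.forward(...).detach().numpy(): a 5-element vector
        image_list.append(numpy.random.rand(5).astype(numpy.float32))

images = numpy.array(image_list)
print(len(image_list))   # 10 snapshots
print(images.shape)      # (10, 5): one row per snapshot
print(images.T.shape)    # (5, 10): one column per snapshot, as plotted
```

Reading the resulting plot from left to right therefore shows the generator's output drifting towards the 01110 pattern as training progresses.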
Training loop
Let’s examine the main.py file.
Inside ...