...

/

Solution Review: Define the Generator

Solution Review: Define the Generator

Learn to define the network architecture and optimiser for the Generator class, and see how its training step uses the Discriminator's loss function.

We'll cover the following...

Solution

Press + to interact
main.py
Generator.py
Discriminator.py
Dataset.py
fashion-mnist_train.csv
import torch
import torch.nn as nn
from torch.utils.data import Dataset
# generator class
class Generator(nn.Module):
    """Generator of a conditional GAN.

    The network maps a 100-element random seed concatenated with a
    10-element label vector (110 inputs) to 784 outputs squashed into
    [0, 1] by a final Sigmoid; ``plot_images`` reshapes those outputs to a
    28x28 image.
    """

    def __init__(self):
        # initialise parent pytorch class
        super().__init__()

        # 110 inputs (100 seed + 10 label) -> 200 hidden units -> 784 outputs
        self.model = nn.Sequential(
            nn.Linear(100 + 10, 200),
            nn.LeakyReLU(0.02),
            nn.LayerNorm(200),
            nn.Linear(200, 784),
            nn.Sigmoid(),
        )

        # Adam optimiser over this generator's parameters only
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # number of training steps taken, and loss samples recorded every 10 steps
        self.counter = 0
        self.progress = []

    def forward(self, seed_tensor, label_tensor):
        """Concatenate the seed and label and run them through the network.

        Concatenation is along dim 0, so this presumably expects 1-D tensors
        (a single sample, no batch dimension) -- that is how ``plot_images``
        calls it.

        Returns:
            The network output (784 values in [0, 1] for 1-D inputs).
        """
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)

    # NOTE(review): this overrides nn.Module.train(mode). nn.Module.eval()
    # calls self.train(False), which would fail against this signature.
    def train(self, D, inputs, label_tensor, targets):
        """Run one generator update against discriminator ``D``.

        Args:
            D: discriminator; must provide ``forward(image, label)`` and a
                ``loss_function`` attribute.
            inputs: random seed fed into the generator.
            label_tensor: class label, given to both the generator and ``D``.
            targets: target for ``D``'s output (presumably the "real" label,
                so the generator learns to fool ``D`` -- confirm with caller).
        """
        # generate an image, then let the discriminator judge it
        g_output = self.forward(inputs, label_tensor)
        d_output = D.forward(g_output, label_tensor)

        # the generator has no loss of its own; it reuses D's loss function
        loss = D.loss_function(d_output, targets)

        # record the loss every 10 steps to keep the progress list small
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())

        # backward() may also populate gradients on D's parameters, but only
        # this generator's optimiser steps here
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_images(self, label):
        """Plot a 2x3 grid of generated 28x28 images for class ``label``.

        NOTE(review): ``plt`` and ``generate_random_seed`` are not defined in
        this file; they are presumably provided by the calling script.
        """
        # one-hot encode the requested class
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0

        _, axarr = plt.subplots(2, 3, figsize=(16, 8))
        # sampling only: no autograd graph is needed
        with torch.no_grad():
            for i in range(2):
                for j in range(3):
                    image = self.forward(generate_random_seed(100), label_tensor)
                    axarr[i, j].imshow(
                        image.detach().cpu().numpy().reshape(28, 28),
                        interpolation='none',
                        cmap='Blues',
                    )

    def plot_progress(self):
        """Plot the generator losses recorded every 10 training steps."""
        import pandas

        df = pandas.DataFrame(self.progress, columns=['loss'])
        # ylim=0 pins the lower y-axis bound at zero
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0))

Explanation

  • Lines 13-21 ...