Update the Neural Network Code
Look at the complete code for the neural network.
The final code
We’ve now worked out how to prepare the inputs for training and querying, as well as the target outputs for training.
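Before we look at the full program, here’s a minimal, standalone sketch of that preparation step, so the two formulas are easy to see in isolation (the label value 3 is just an illustrative example):

import numpy

# raw pixel values range from 0 to 255; rescale them into the range 0.01 to 1.00
raw_pixels = numpy.array([0, 128, 255], dtype=float)
inputs = (raw_pixels / 255.0 * 0.99) + 0.01
print(inputs)   # [0.01  0.5069...  1.0]

# target output values: all 0.01, except 0.99 at the position of the desired label
output_nodes = 10
label = 3   # illustrative example label
targets = numpy.zeros(output_nodes) + 0.01
targets[label] = 0.99
print(targets)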
Let’s update our Python code to include this work. Here’s the code we’ve developed so far:
import numpy
# scipy.special for the sigmoid function expit()
import scipy.special
# library for plotting arrays
import matplotlib.pyplot

# neural network class definition
class neuralNetwork:

    # initialise the neural network
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # set number of nodes in each input, hidden, output layer
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # link weight matrices, wih and who
        # weights inside the arrays are w_i_j, where link is from node i to node j in the next layer
        # w11 w21
        # w12 w22 etc
        self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))

        # learning rate
        self.lr = learningrate

        # activation function is the sigmoid function
        self.activation_function = lambda x: scipy.special.expit(x)
        pass

    # train the neural network
    def train(self, inputs_list, targets_list):
        # convert inputs list to 2d array
        inputs = numpy.array(inputs_list, ndmin=2).T
        targets = numpy.array(targets_list, ndmin=2).T

        # calculate signals into hidden layer
        hidden_inputs = numpy.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)

        # calculate signals into final output layer
        final_inputs = numpy.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)

        # output layer error is the (target - actual)
        output_errors = targets - final_outputs
        # hidden layer error is the output_errors, split by weights, recombined at hidden nodes
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # update the weights for the links between the hidden and output layers
        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)), numpy.transpose(hidden_outputs))

        # update the weights for the links between the input and hidden layers
        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), numpy.transpose(inputs))
        pass

    # query the neural network
    def query(self, inputs_list):
        # convert inputs list to 2d array
        inputs = numpy.array(inputs_list, ndmin=2).T

        # calculate signals into hidden layer
        hidden_inputs = numpy.dot(self.wih, inputs)
        # calculate the signals emerging from hidden layer
        hidden_outputs = self.activation_function(hidden_inputs)

        # calculate signals into final output layer
        final_inputs = numpy.dot(self.who, hidden_outputs)
        # calculate the signals emerging from final output layer
        final_outputs = self.activation_function(final_inputs)

        return final_outputs

# number of input, hidden and output nodes
input_nodes = 784
hidden_nodes = 100
output_nodes = 10

# learning rate is 0.3
learning_rate = 0.3

# create instance of neural network
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)

# load the mnist training data CSV file into a list
training_data_file = open("mnist_train_100.csv", 'r')
training_data_list = training_data_file.readlines()
training_data_file.close()

# train the neural network

# go through all records in the training data set
for record in training_data_list:
    # split the record by the ',' commas
    all_values = record.split(',')
    # scale and shift the inputs
    inputs = (numpy.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
    # create the target output values (all 0.01, except the desired label which is 0.99)
    targets = numpy.zeros(output_nodes) + 0.01
    # all_values[0] is the target label for this record
    targets[int(all_values[0])] = 0.99
    n.train(inputs, targets)
    pass
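With training complete, we can give the network a quick, informal sanity check by querying it with one of the records it has just seen. This isn’t a proper evaluation, which needs a separate test set we haven’t loaded yet; the sketch below reuses only the variables defined above:

# informal sanity check: query the network with the first training record
all_values = training_data_list[0].split(',')
inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
outputs = n.query(inputs)
print("correct label:", all_values[0])
print("network output:", outputs)
# the index of the largest output value is the network's answer
print("network's answer:", numpy.argmax(outputs))

Here numpy.asarray(..., dtype=float) does the same job as the numpy.asfarray call in the main code; asfarray was removed in NumPy 2.0, so the newer form is safer on recent installations.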
We can see we’ve imported the plotting library at the top, added some code to set the size of the input, hidden, and output layers, read the smaller MNIST training dataset, and then trained the neural network with those records.
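The plotting library isn’t actually used in the code above yet, but it’s handy for checking that a record really does look like a handwritten digit. As a short sketch, here’s one way to display the first training record as a 28×28 image (in a plain script you may need the final show() call; in a notebook it’s often unnecessary):

# reshape the 784 pixel values of the first record back into a 28x28 array and plot it
all_values = training_data_list[0].split(',')
image_array = numpy.asarray(all_values[1:], dtype=float).reshape((28, 28))
matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='none')
matplotlib.pyplot.show()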