Rethinking the Training Loop
Learn how to reduce the boilerplate in the training loop by using higher-order functions.
Training step
As already mentioned, the higher-order function that builds a training step function for us takes the key elements of our training loop: the model, the loss function, and the optimizer. The training step function it returns will take two arguments, features and labels, and will return the corresponding loss value.
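If higher-order functions are new to you, here is a minimal sketch of the pattern on its own, outside of PyTorch (the make_multiplier name is just for illustration): the outer function takes some configuration, and the inner function it returns remembers that configuration through a closure.

def make_multiplier(factor):
    # Builds a function that multiplies its input by a fixed factor
    def multiply(x):
        # "factor" is captured from the enclosing scope (a closure)
        return x * factor
    # Returns the inner function instead of a value
    return multiply

double = make_multiplier(2)
double(7)  # returns 14

Our make_train_step() follows exactly the same pattern: the model, loss function, and optimizer play the role of factor above.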
Creating the higher-order function for the training step
Apart from returning the loss value, the inner perform_train_step()
function below is the same as the code inside the loop in model training V0. The code should look like this:
def make_train_step(model, loss_fn, optimizer):
    # Builds function that performs a step in the train loop
    def perform_train_step(x, y):
        # Sets model to TRAIN mode
        model.train()
        # Step 1 - computes model's predictions - forward pass
        yhat = model(x)
        # Step 2 - computes the loss
        loss = loss_fn(yhat, y)
        # Step 3 - computes gradients for "b" and "w" parameters
        loss.backward()
        # Step 4 - updates parameters using gradients and
        # the learning rate
        optimizer.step()
        optimizer.zero_grad()
        # Returns the loss
        return loss.item()
    # Returns the function that will be called inside the
    # train loop
    return perform_train_step
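Once built, the returned function can be called once per epoch with the training tensors. As a preview of how it will be used, here is a sketch only; x_train_tensor, y_train_tensor, and the number of epochs are assumed from the earlier steps, and the full loop appears later:

# Sketch of the eventual training loop, assuming x_train_tensor
# and y_train_tensor were created in the data preparation step
n_epochs = 1000
losses = []
for epoch in range(n_epochs):
    # Performs one forward pass, backward pass, and parameter update
    loss = train_step(x_train_tensor, y_train_tensor)
    losses.append(loss)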
Updating the model configuration code
Next, we need to update our model configuration code to call this higher-order function and build a train_step function. But first, we need to run the data preparation script:
%run -i data_preparation/v0.py
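For reference, data_preparation/v0.py is assumed (from the earlier data preparation step) to look roughly like this; it converts the Numpy training arrays into tensors and sends them to the chosen device:

# A sketch of data_preparation/v0.py, assuming x_train and
# y_train are the Numpy arrays built earlier
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transforms Numpy arrays into PyTorch tensors and sends them
# to the chosen device
x_train_tensor = torch.as_tensor(x_train).float().to(device)
y_train_tensor = torch.as_tensor(y_train).float().to(device)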
The updated model configuration code looks like this:
%%writefile model_configuration/v1.py

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sets learning rate - this is "eta" ~ the "n"-like Greek letter
lr = 0.1

torch.manual_seed(42)
# Now we can create a model and send it at once to the device
model = nn.Sequential(nn.Linear(1, 1)).to(device)

# Defines an SGD optimizer to update the parameters
optimizer = optim.SGD(model.parameters(), lr=lr)

# Defines an MSE loss function
loss_fn = nn.MSELoss(reduction='mean')

# Creates the train_step function for our model, loss function
# and optimizer
train_step = make_train_step(model, loss_fn, optimizer)  # 1)
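After writing the file, we can run it in the same way and sanity-check the result: calling train_step once should perform a single training step and return the loss as a plain Python float. A sketch, assuming the training tensors from the data preparation step are already defined:

%run -i model_configuration/v1.py

# One call = one full training step; the result is a Python float
# because perform_train_step() returns loss.item()
loss = train_step(x_train_tensor, y_train_tensor)
print(loss)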
...