Saving and Loading Methods

Learn how you can now save and load models as well as make predictions using class attributes.


Saving and loading

Most of the code here is the same as in the chapter Rethinking the Training Loop. The only difference is that it now uses class attributes (such as `self.model` and `self.optimizer`) instead of local variables.

The updated method for saving checkpoints should look like this now:

```python
import torch

def save_checkpoint(self, filename):
    # Builds dictionary with all elements for resuming training
    checkpoint = {
        'epoch': self.total_epochs,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': self.losses,
        'val_loss': self.val_losses
    }
    torch.save(checkpoint, filename)

# Attaches the function to the StepByStep class as a method
setattr(StepByStep, 'save_checkpoint', save_checkpoint)
```
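To see what actually ends up on disk, here is a minimal, self-contained sketch. `TinyTrainer` is a hypothetical stand-in for `StepByStep` that defines only the attributes `save_checkpoint` reads, and the file name `model_checkpoint.pth` and the loss values are made up for illustration. The function is attached with `setattr`, as above, and the saved file is loaded back to inspect its contents.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class TinyTrainer:
    # Hypothetical stand-in for StepByStep: only the attributes
    # that save_checkpoint reads are defined here
    def __init__(self):
        self.model = nn.Linear(1, 1)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1)
        self.total_epochs = 3
        self.losses = [0.9, 0.5, 0.3]       # made-up training losses
        self.val_losses = [1.0, 0.6, 0.4]   # made-up validation losses


def save_checkpoint(self, filename):
    # Same dictionary as in the method above
    checkpoint = {
        'epoch': self.total_epochs,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'loss': self.losses,
        'val_loss': self.val_losses
    }
    torch.save(checkpoint, filename)


# Attach the function as a method, as done for StepByStep
setattr(TinyTrainer, 'save_checkpoint', save_checkpoint)

trainer = TinyTrainer()
path = os.path.join(tempfile.mkdtemp(), 'model_checkpoint.pth')
trainer.save_checkpoint(path)

# Load the raw dictionary back to see what was stored
saved = torch.load(path)
print(sorted(saved.keys()))
# ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict', 'val_loss']
print(list(saved['model_state_dict'].keys()))
# ['weight', 'bias']
```

The checkpoint is just a Python dictionary, so anything else needed to resume training can be added as another key.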

The method for loading checkpoints should look like the following:

```python
def load_checkpoint(self, filename):
    # Loads dictionary
    checkpoint = torch.load(filename)

    # Restore state for model and optimizer
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(
        checkpoint['optimizer_state_dict']
    )

    self.total_epochs = checkpoint['epoch']
    self.losses = checkpoint['loss']
    self.val_losses = checkpoint['val_loss']

    self.model.train() # always use TRAIN for resuming training

setattr(StepByStep, 'load_checkpoint', load_checkpoint)
```
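A save/load round trip can be sketched as follows. `TinyTrainer` is again a hypothetical stand-in for `StepByStep`, with both methods defined inline (rather than attached via `setattr`) to keep the example short, and the toy data `y = 2x + 1` is made up. After one training step, the trainer is saved, a freshly initialized trainer loads the checkpoint, and its weights and history match the original.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


class TinyTrainer:
    # Hypothetical stand-in for StepByStep
    def __init__(self):
        self.model = nn.Linear(1, 1)
        # Momentum gives the optimizer per-parameter state worth restoring
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9)
        self.total_epochs = 0
        self.losses = []
        self.val_losses = []

    def save_checkpoint(self, filename):
        torch.save({
            'epoch': self.total_epochs,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss': self.losses,
            'val_loss': self.val_losses
        }, filename)

    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.total_epochs = checkpoint['epoch']
        self.losses = checkpoint['loss']
        self.val_losses = checkpoint['val_loss']
        self.model.train()  # always use TRAIN for resuming training


torch.manual_seed(42)
x = torch.randn(16, 1)
y = 2 * x + 1  # made-up toy data

original = TinyTrainer()
# One training step, so weights and momentum buffers move away from init
loss = nn.functional.mse_loss(original.model(x), y)
loss.backward()
original.optimizer.step()
original.optimizer.zero_grad()
original.total_epochs = 1
original.losses = [loss.item()]
original.val_losses = [loss.item()]

path = os.path.join(tempfile.mkdtemp(), 'model_checkpoint.pth')
original.save_checkpoint(path)

resumed = TinyTrainer()  # freshly initialized, different weights
resumed.load_checkpoint(path)

print(resumed.total_epochs)  # 1
print(torch.equal(resumed.model.weight, original.model.weight))  # True
```

Because the optimizer state (here, the momentum buffers) is restored along with the weights, the next `optimizer.step()` continues from where training stopped instead of starting the momentum from scratch.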

Notice that the model is set to training mode after loading the checkpoint, through the `self.model.train()` call at the end of the method.
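The explicit call matters because training/evaluation mode is a flag on each module, not part of its state dictionary, so loading weights leaves the current mode untouched. The small sketch below illustrates this with a hypothetical model containing dropout, a layer whose behavior depends on the mode.

```python
import torch.nn as nn

# Hypothetical small model; Dropout behaves differently in train and eval modes
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
state = model.state_dict()
print(list(state.keys()))  # ['0.weight', '0.bias'] (no mode flag stored)

model.eval()  # e.g., left in eval mode after making predictions
model.load_state_dict(state)
mode_after_load = model.training
print(mode_after_load)  # False: loading weights does not change the mode

model.train()  # what load_checkpoint does before training resumes
print(model.training)  # True
print(all(m.training for m in model.modules()))  # True: train() is applied recursively
```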
