Saving and Loading Methods
Learn how you can now save and load models as well as make predictions using class attributes.
Saving and loading
Most of the code here is the same as the code from the chapter Rethinking the Training Loop. The only difference is that we now use the attributes stored in the class (self.model, self.optimizer, self.total_epochs, and so on) instead of local variables.
The updated method for saving checkpoints now looks like this:
```python
def save_checkpoint(self, filename):
    # Builds dictionary with all elements for resuming training
    checkpoint = {'epoch': self.total_epochs,
                  'model_state_dict': self.model.state_dict(),
                  'optimizer_state_dict': self.optimizer.state_dict(),
                  'loss': self.losses,
                  'val_loss': self.val_losses}

    torch.save(checkpoint, filename)

setattr(StepByStep, 'save_checkpoint', save_checkpoint)
```
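The setattr call at the end of the code above attaches a plain function to the already-defined StepByStep class as a method. The short sketch below shows that pattern in isolation, using a hypothetical stand-in class called Tiny and a hypothetical describe method; neither is part of the course code.

```python
# Hypothetical stand-in class with a single attribute, to keep things small
class Tiny:
    def __init__(self):
        self.total_epochs = 0

# An instance created BEFORE the method is attached
early = Tiny()

def describe(self):
    # 'self' is bound to the instance, exactly as in a regular method
    return f"trained for {self.total_epochs} epochs"

# Attach the plain function to the class under the name 'describe'
setattr(Tiny, 'describe', describe)

later = Tiny()
later.total_epochs = 3

print(early.describe())  # trained for 0 epochs
print(later.describe())  # trained for 3 epochs
```

Because Python looks methods up on the class at call time, even objects created before the setattr call gain the new method, which is what allows a class to be extended piece by piece, as in this section.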
Likewise, the method for loading checkpoints should look like this:
```python
def load_checkpoint(self, filename):
    # Loads dictionary
    checkpoint = torch.load(filename)

    # Restore state for model and optimizer
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    self.total_epochs = checkpoint['epoch']
    self.losses = checkpoint['loss']
    self.val_losses = checkpoint['val_loss']

    self.model.train() # always use TRAIN for resuming training

setattr(StepByStep, 'load_checkpoint', load_checkpoint)
```
Notice that, after the checkpoint is loaded, the model is set to training mode with self.model.train(), since the purpose of loading it here is to resume training.
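To see the two methods working together, here is a minimal round-trip sketch, assuming PyTorch is installed. MiniStepByStep, make_sbs, the file name model_checkpoint.pth, and the loss values are all hypothetical: the class is a stripped-down stand-in for StepByStep that keeps only the attributes the checkpoint methods use, and its method bodies mirror the ones shown above. A second object with different random weights, deliberately left in eval mode, loads the checkpoint and should end up with the same weights, epoch count, and losses, back in training mode.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical, stripped-down stand-in for StepByStep: only the
# attributes touched by the checkpoint methods are kept
class MiniStepByStep:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.total_epochs = 0
        self.losses = []
        self.val_losses = []

    # Same body as save_checkpoint above
    def save_checkpoint(self, filename):
        checkpoint = {'epoch': self.total_epochs,
                      'model_state_dict': self.model.state_dict(),
                      'optimizer_state_dict': self.optimizer.state_dict(),
                      'loss': self.losses,
                      'val_loss': self.val_losses}
        torch.save(checkpoint, filename)

    # Same body as load_checkpoint above
    def load_checkpoint(self, filename):
        checkpoint = torch.load(filename)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.total_epochs = checkpoint['epoch']
        self.losses = checkpoint['loss']
        self.val_losses = checkpoint['val_loss']
        self.model.train()

# Hypothetical helper: a tiny linear model and SGD optimizer
def make_sbs(seed):
    torch.manual_seed(seed)
    model = nn.Linear(1, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    return MiniStepByStep(model, optimizer)

# Pretend we already trained for two epochs (made-up loss values)
original = make_sbs(seed=42)
original.total_epochs = 2
original.losses = [0.9, 0.5]
original.val_losses = [1.0, 0.6]

path = os.path.join(tempfile.mkdtemp(), 'model_checkpoint.pth')
original.save_checkpoint(path)

# A fresh object with different weights, left in eval mode on purpose
resumed = make_sbs(seed=13)
resumed.model.eval()
resumed.load_checkpoint(path)

print(resumed.total_epochs)    # 2
print(resumed.model.training)  # True
```

Setting the model to eval mode before loading makes the point of the final self.model.train() call visible: whatever mode the object was in, it comes back from load_checkpoint ready to keep training.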