Saving and Loading Models
Learn how you can save and load models using PyTorch.
Why save models?
Training a model successfully is great; there’s no doubt about that. But not all models train that fast, and training may get interrupted (a computer crash, a timeout after 12 hours of continuous GPU usage on Google Colab, etc.). It would be a pity to have to start over, right?
So, it is important to be able to checkpoint, or save, our model: that is, write it to disk so that we can resume training later or deploy it in an application that makes predictions.
Model state
To checkpoint a model, we basically have to save its state to a file so that it can be loaded back later. Nothing special, actually.
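A minimal round trip of that idea might look like the sketch below. It assumes a toy `nn.Linear(3, 1)` model and a file in the system temp directory, both chosen only for illustration: the learnable state is saved with `torch.save()` and loaded into a freshly built model of the same architecture with `load_state_dict()`.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(42)

# Hypothetical toy model; any nn.Module is saved and loaded the same way.
model = nn.Linear(3, 1)

# Save only the learnable state (weights and biases) to a file.
path = os.path.join(tempfile.gettempdir(), "linear_state.pth")
torch.save(model.state_dict(), path)

# Later, possibly in another process: rebuild the same architecture
# and load the saved state back into it.
restored = nn.Linear(3, 1)
restored.load_state_dict(torch.load(path))

# Both models now hold identical parameters.
same = all(
    torch.equal(p, q) for p, q in zip(model.parameters(), restored.parameters())
)
print(same)  # True
```

Note that the file holds parameters only, not the class definition, so the model has to be constructed in code before its state can be loaded.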
But what actually defines the state of a model? Here is an overview:
- model.state_dict(): Kinda obvious, right?
- optimizer.state_dict(): Remember, optimizers have a state_dict as well.
- Losses: After all, you should keep track of their evolution.
- Epoch: It is just a number, so why not?
- Anything else you’d like to have restored later.
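To see why the optimizer's state_dict belongs in this list, here is a small sketch. It assumes a toy `nn.Linear(2, 1)` model and SGD with momentum on random dummy data, all chosen only for illustration. After a single step, the optimizer carries per-parameter momentum buffers that a freshly created optimizer would not have.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and optimizer, for illustration only.
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# The model's state: an ordered mapping from parameter names to tensors.
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# Before any step, SGD has not created any per-parameter buffers yet.
print(optimizer.state_dict()["state"])  # {}

# One training step on random dummy data.
x, y = torch.randn(4, 2), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The optimizer's state now includes a momentum buffer per parameter,
# plus its hyperparameters under 'param_groups'.
opt_state = optimizer.state_dict()
print(sorted(opt_state.keys()))  # ['param_groups', 'state']
print(sorted(opt_state["state"][0].keys()))  # ['momentum_buffer']
```

Restarting from the model weights alone would silently reset these buffers, which is why the optimizer state is worth saving too.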
Saving checkpoint
After defining the model state, we now have to wrap everything into a Python dictionary and use torch.save() to dump it all into a file. Easy peasy! We have just saved our model to a file named model_checkpoint.pth.
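A minimal sketch of this step follows. It assumes a toy `nn.Linear(2, 1)` model, SGD with momentum, and a short dummy training loop, none of which come from the original lesson. The dictionary key names (`epoch`, `model_state_dict`, `optimizer_state_dict`, `losses`) are our own hypothetical choice; `torch.save()` accepts any such dictionary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and optimizer, for illustration only.
model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Hypothetical short training loop that records the loss of each epoch.
x, y = torch.randn(16, 2), torch.randn(16, 1)
losses = []
for epoch in range(3):
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Wrap everything that defines the training state into one dictionary.
# The key names are hypothetical; any names work as long as loading uses the same ones.
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "losses": losses,
}

# Dump it all into a single file.
torch.save(checkpoint, "model_checkpoint.pth")
```

Because the file holds a plain dictionary, torch.load() returns the same keys, and the two state dictionaries can be passed to model.load_state_dict() and optimizer.load_state_dict() to resume training.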