Saving and Loading Models
Learn how you can save and load models using PyTorch.
Why save models?
Training a model successfully is great, no doubt about that. But not every model trains quickly, and training may get interrupted (a computer crash, a timeout after 12 hours of continuous GPU usage on Google Colab, etc.). It would be a pity to have to start over, right?
So, it is important to be able to checkpoint, or save, our model: that is, to write it to disk so that we can resume training later or deploy it as an application to make predictions.
Model state
To checkpoint a model, we basically have to save its state to a file so that it can be loaded back later. Nothing special, actually.
But what actually defines the state of a model? Here is an overview:
- model.state_dict(): Kinda obvious, right?
- optimizer.state_dict(): Remember, optimizers have a state_dict as well.
- Losses: After all, you should keep track of their evolution.
- Epoch: ...
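To make the list above concrete, here is a minimal sketch of how these pieces could be bundled into a single checkpoint with torch.save and restored with torch.load. The tiny linear model, the SGD optimizer, the dictionary keys, and the checkpoint.pth file name are all assumptions chosen for illustration; they stand in for whatever your own training setup uses.

```python
import os
import tempfile

import torch
from torch import nn, optim

torch.manual_seed(42)

# Hypothetical stand-ins for a real model and optimizer
model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

# A single training step on made-up data, so there is some state to save
x = torch.randn(8, 1)
y = 2 * x + 1
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

losses = [loss.item()]  # loss history so far
epoch = 1               # number of epochs completed

# Bundle everything that defines the training state into one dictionary
# (the key names are an arbitrary choice for this sketch)
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "losses": losses,
}
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
torch.save(checkpoint, path)

# Later: rebuild the same architecture and optimizer, then load the saved state
restored_model = nn.Linear(1, 1)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.1)

loaded = torch.load(path)
restored_model.load_state_dict(loaded["model_state_dict"])
restored_optimizer.load_state_dict(loaded["optimizer_state_dict"])
restored_epoch = loaded["epoch"]
restored_losses = loaded["losses"]

restored_model.train()  # use .eval() instead when loading only to make predictions
```

Note that the state dictionaries hold only parameters and optimizer settings, not the model class itself, so the model and optimizer have to be created first and then populated from the checkpoint.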