The pickle Python library implements binary protocols for serializing and de-serializing a Python object.
When you import torch (or when you use PyTorch) it will import pickle for you, so you don't need to call pickle.dump() and pickle.load() directly, which are the methods to save and to load the object.
In fact, torch.save() and torch.load() wrap pickle.dump() and pickle.load() for you.
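As a small illustration of that wrapping, here is a sketch that saves an ordinary Python dictionary and loads it back. The dictionary contents and the in-memory buffer are my own choices for the example; a file path would work the same way.

```python
import io
import torch

# An ordinary Python object that happens to contain a tensor.
obj = {"step": 3, "weights": torch.arange(6.0).reshape(2, 3)}

# torch.save accepts a path or a file-like object; an in-memory buffer
# keeps this example free of temporary files.
buffer = io.BytesIO()
torch.save(obj, buffer)

# Rewind the buffer and load the object back.
buffer.seek(0)
restored = torch.load(buffer)

print(restored["step"])
print(torch.equal(restored["weights"], obj["weights"]))
```

No explicit pickle call appears anywhere; torch.save() and torch.load() take care of the serialization.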
The state_dict the other answer mentioned deserves just a few more notes.
What state_dict do we have inside PyTorch?
There are actually two state_dicts.
A PyTorch model is a torch.nn.Module, which has a model.parameters() call to get the learnable parameters (w and b).
These learnable parameters, once randomly initialized, are updated over time as we train.
Learnable parameters are the first state_dict.
The second state_dict is the optimizer state dict. Recall that the optimizer is used to improve our learnable parameters. But the optimizer state_dict itself is not learned: it holds the hyperparameters and the optimizer's own bookkeeping, so there is nothing to learn in there.
Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
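To make the "altered and restored" part concrete, here is a minimal sketch. The layer sizes and the choice to zero the bias are arbitrary, picked only for illustration.

```python
import torch

model = torch.nn.Linear(5, 2)

# Copy the entries so that editing them does not touch the live parameters.
state = {name: tensor.clone() for name, tensor in model.state_dict().items()}

# Alter one entry like any other dictionary value.
state["bias"] = torch.zeros_like(state["bias"])

# Restore the edited dictionary into the model.
model.load_state_dict(state)
print(model.bias)
```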
Let's create a super simple model to explain this:
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print each entry of the model's state_dict with its shape
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")
print(model.weight)

print("Model bias:")
print(model.bias)

print("---")
# Print each entry of the optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
This code will output the following:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
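The empty state {} in that output is what a freshly created optimizer looks like. As a rough sketch, assuming SGD with momentum as above, a single training step fills it in; the random input and the squared-output loss are placeholders I made up for the example.

```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

entries_before = len(optimizer.state_dict()["state"])

# One training step on random data (the data and the loss are placeholders).
x = torch.randn(4, 5)
loss = model(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

state_after = optimizer.state_dict()["state"]
entries_after = len(state_after)

print(entries_before, entries_after)
for key, value in state_after.items():
    print(key, list(value.keys()))
```

After the step, each parameter has its own entry holding a momentum buffer, which is the kind of bookkeeping you would want to keep if you plan to resume training.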
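Note that this is a minimal model. You can also try a stack of layers with torch.nn.Sequential (the sizes below are example values, and the stack is only meant for inspecting its state_dict, not for a forward pass):
D_in, H, D_out = 10, 5, 2  # example layer sizes
A, B, C = 3, 8, 3  # example in_channels, out_channels, kernel_size
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.Conv2d(A, B, C),
    torch.nn.Linear(H, D_out),
)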
Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model's state_dict.
Training state that is neither a parameter nor a buffer belongs to the optimizer object's state_dict, which contains information about the optimizer's state, as well as the hyperparameters used.
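To see which layers contribute entries, here is a small sketch with one layer of each kind. The sizes are arbitrary; the keys are prefixed with each layer's position in the Sequential.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 3),     # learnable weight and bias
    torch.nn.ReLU(),           # no parameters, no buffers
    torch.nn.BatchNorm1d(3),   # learnable weight/bias plus registered buffers
)

keys = list(model.state_dict().keys())
for key in keys:
    print(key)
```

The ReLU at position 1 contributes nothing, while the batchnorm layer at position 2 adds its running statistics next to its weight and bias.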
The rest of the story is the same: in the inference phase (when we use the model after training), we predict based on the parameters we learned. So for inference, we just need to save the parameters, model.state_dict():
torch.save(model.state_dict(), filepath)
And to use it later:
model.load_state_dict(torch.load(filepath))
model.eval()
Note: Don't forget the last line, model.eval(). It is crucial after loading the model, because it switches layers such as dropout and batch normalization to evaluation mode.
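Put together, a full round trip could look like the sketch below. The make_model helper, the architecture, and the file name are hypothetical, and the "trained" model is just a freshly initialized stand-in.

```python
import os
import tempfile
import torch

def make_model():
    # Hypothetical architecture; dropout makes train and eval modes differ.
    return torch.nn.Sequential(
        torch.nn.Linear(5, 8),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(8, 2),
    )

trained = make_model()  # stands in for a model you have already trained

with tempfile.TemporaryDirectory() as tmp:
    filepath = os.path.join(tmp, "model_state.pt")
    torch.save(trained.state_dict(), filepath)

    # Rebuild the same architecture, then load the saved parameters into it.
    restored = make_model()
    restored.load_state_dict(torch.load(filepath))

# Switch both models to evaluation mode so dropout is disabled.
trained.eval()
restored.eval()

x = torch.randn(3, 5)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```

Note that the state_dict holds only tensors, so the code that builds the architecture has to be available when you load.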
Also, don't try torch.save(model.parameters(), filepath). model.parameters() is just a generator object.
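A quick check of that claim, as a sketch with an arbitrary small layer:

```python
import types
import torch

model = torch.nn.Linear(5, 2)

params = model.parameters()
is_generator = isinstance(params, types.GeneratorType)
print(is_generator)

# A generator can be consumed only once.
first_pass = list(params)
second_pass = list(params)
print(len(first_pass), len(second_pass))

# The state_dict is a dictionary of named tensors, which is what you want to save.
is_dict = isinstance(model.state_dict(), dict)
print(is_dict)
```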
On the other hand, torch.save(model, filepath) saves the model object itself, but keep in mind that the model doesn't contain the optimizer's state_dict. Check the other excellent answer by @Jadiel de Armas to save the optimizer's state dict.
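For completeness, one common way to keep both state_dicts together is to put them in a single dictionary. This is only a sketch, not necessarily what that answer does: the key names are arbitrary labels, the training step is a placeholder, and an in-memory buffer stands in for a file path.

```python
import io
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One placeholder step so the optimizer has momentum buffers worth saving.
model(torch.randn(4, 5)).sum().backward()
optimizer.step()

# The key names are arbitrary labels chosen for this sketch.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
buffer = io.BytesIO()  # a file path works the same way
torch.save(checkpoint, buffer)

# Later: rebuild the objects, then restore both state_dicts.
buffer.seek(0)
loaded = torch.load(buffer)
new_model = torch.nn.Linear(5, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
new_model.load_state_dict(loaded["model_state_dict"])
new_optimizer.load_state_dict(loaded["optimizer_state_dict"])
```

This matters when you want to resume training rather than just run inference.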