[python] Best way to save a trained model in PyTorch?

I was looking for alternative ways to save a trained model in PyTorch. So far, I have found two alternatives.

  1. torch.save() to save a model and torch.load() to load a model.
  2. model.state_dict() to save a trained model and model.load_state_dict() to load the saved model.

I came across this discussion, where approach 2 is recommended over approach 1.

My question is: why is the second approach preferred? Is it only because torch.nn modules have those two functions and we are encouraged to use them?

This question is related to python serialization deep-learning pytorch tensor

The answer is


The pickle Python library implements binary protocols for serializing and de-serializing a Python object.

PyTorch uses pickle internally, so you don't need to call pickle.dump() and pickle.load() (the functions that save and load an object) directly.

In fact, torch.save() and torch.load() wrap pickle.dump() and pickle.load() for you.
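A minimal sketch of that round trip, assuming an in-memory buffer (io.BytesIO) stands in for a file on disk. The dictionary contents are made up for illustration:

```python
import io

import torch

# A hypothetical object to serialize: a plain dict holding a tensor.
obj = {"step": 3, "values": torch.arange(4, dtype=torch.float32)}

# torch.save accepts a path or any writable file-like object.
buffer = io.BytesIO()
torch.save(obj, buffer)

# Rewind the buffer and read the object back.
buffer.seek(0)
restored = torch.load(buffer)

print(restored["step"])
print(torch.equal(obj["values"], restored["values"]))
```

The same two calls work with a file path in place of the buffer.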

The state_dict that the other answer mentioned deserves a few more notes.

What state_dict do we have inside PyTorch? There are actually two state_dicts.

A PyTorch model is a torch.nn.Module, and model.parameters() returns its learnable parameters (w and b). These learnable parameters, once randomly initialized, are updated over time as the model learns. The learnable parameters make up the first state_dict.

The second state_dict is the optimizer's state dict. Recall that the optimizer is used to improve our learnable parameters. The optimizer's state_dict itself contains nothing learnable: it holds the optimizer's hyperparameters and its internal bookkeeping state.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
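As a small illustration of treating a state_dict like an ordinary dictionary, here is a sketch that overwrites one entry and loads the result back. The layer sizes and the choice to zero the bias are arbitrary assumptions:

```python
import torch

model = torch.nn.Linear(5, 2)

# state_dict() returns an ordered dictionary mapping names to tensors.
sd = model.state_dict()
print(list(sd.keys()))

# Alter an entry like any other dictionary value, then load it back.
sd["bias"] = torch.zeros(2)
model.load_state_dict(sd)

print(model.bias)
```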

Let's create a super simple model to explain this:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

This code will output the following:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
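The empty `state` entry in this output reflects an optimizer that has not taken a step yet. A minimal sketch of how that entry fills in after one update, assuming random input data and a summed output as a stand-in loss:

```python
import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Before any update, the per-parameter state is empty.
state_before = optimizer.state_dict()["state"]
print(state_before)

# One hypothetical training step on random data.
x = torch.randn(8, 5)
loss = model(x).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# With momentum enabled, each parameter now has a momentum buffer.
state_after = optimizer.state_dict()["state"]
for key, value in state_after.items():
    print(key, list(value.keys()))
```

This is why the optimizer's state_dict matters when training is resumed: the buffers are part of the training state even though they are not model parameters.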

Note that this is a minimal model. You can try a stack of layers in a sequential container instead:

# D_in, H, A, B, C and D_out are placeholder sizes to fill in
model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C),
          torch.nn.Linear(H, D_out),
        )

Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm layers) have entries in the model's state_dict.
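To see this concretely, here is a sketch with one parameter-free layer (ReLU) and one layer with registered buffers (BatchNorm1d). The layer sizes are hypothetical and chosen only for illustration:

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 3),
    torch.nn.ReLU(),          # no parameters, no buffers
    torch.nn.BatchNorm1d(3),  # parameters plus registered buffers
)

keys = list(model.state_dict().keys())
for key in keys:
    print(key)
```

The ReLU at index 1 contributes no keys, while the BatchNorm layer at index 2 contributes its running statistics alongside its weight and bias.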

The non-learnable training state belongs to the optimizer object's state_dict, which contains information about the optimizer's state as well as the hyperparameters used.

The rest of the story is the same: in the inference phase (when we use the trained model to predict), we predict based on the parameters we learned. So for inference, we only need to save the parameters, model.state_dict().

torch.save(model.state_dict(), filepath)

And to use it later:

model.load_state_dict(torch.load(filepath))
model.eval()

Note: Don't forget the call to model.eval(). It is crucial after loading the model for inference, because it switches layers such as Dropout and BatchNorm to evaluation mode.
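A short sketch of what model.eval() changes, assuming a hypothetical model that contains a Dropout layer:

```python
import torch

# Hypothetical model with a Dropout layer.
model = torch.nn.Sequential(torch.nn.Linear(5, 5), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 5)

# Modules start in training mode.
was_training = model.training
print(was_training)

model.eval()
with torch.no_grad():
    first = model(x)
    second = model(x)

# In evaluation mode Dropout is a no-op, so repeated calls agree.
print(model.training)
print(torch.equal(first, second))
```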

Also, don't try to save the parameters with torch.save(model.parameters(), filepath). model.parameters() is just a generator object.

On the other hand, torch.save(model, filepath) saves the model object itself, but keep in mind that the model does not contain the optimizer's state_dict. See the other excellent answer by @Jadiel de Armas for how to save the optimizer's state dict.
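A quick check of what model.parameters() returns, as a sketch with an arbitrary small layer:

```python
import types

import torch

model = torch.nn.Linear(5, 2)

params = model.parameters()
print(isinstance(params, types.GeneratorType))

# A generator is consumed once; a second pass yields nothing.
first_pass = list(params)
second_pass = list(params)
print(len(first_pass), len(second_pass))
```

A generator also carries no parameter names, which is another reason to save the state_dict instead.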


A common PyTorch convention is to save models using either a .pt or .pth file extension.

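Save/Load Entire Model

Save:

path = "username/directory/lstmmodelgpu.pth"
torch.save(model, path)

Load:

# Model class must be defined somewhere
# On PyTorch 2.6 and later, torch.load defaults to weights_only=True, so loading
# a full pickled model needs weights_only=False (only for files you trust)
model = torch.load(path)
model.eval()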

It depends on what you want to do.

Case # 1: Save the model to use it yourself for inference: You save the model, you restore it, and then you change the model to evaluation mode. This is done because you usually have BatchNorm and Dropout layers that by default are in train mode on construction:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Case # 2: Save model to resume training later: If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, epochs, score, etc. You would do it like this:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

To resume training you would do things like: state = torch.load(filepath), and then, to restore the state of each individual object, something like this:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
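Putting Case # 2 together, here is a minimal end-to-end sketch. The helper make_model_and_optimizer, the temporary file location, the random data, and the single training step are all assumptions made for illustration:

```python
import os
import tempfile

import torch
import torch.optim as optim


def make_model_and_optimizer():
    # Hypothetical helper: rebuilds the same architecture and optimizer.
    model = torch.nn.Linear(5, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    return model, optimizer


# Train for one step so the optimizer has state worth saving.
model, optimizer = make_model_and_optimizer()
loss = model(torch.randn(8, 5)).sum()
loss.backward()
optimizer.step()

filepath = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
state = {
    "epoch": 1,
    "state_dict": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(state, filepath)

# Later: rebuild the objects first, then restore their states.
new_model, new_optimizer = make_model_and_optimizer()
state = torch.load(filepath)
new_model.load_state_dict(state["state_dict"])
new_optimizer.load_state_dict(state["optimizer"])
start_epoch = state["epoch"] + 1

new_model.train()  # stay in training mode when resuming
```

Note that the model and optimizer must be constructed before their states can be loaded; the checkpoint holds only values, not the objects themselves.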

Case # 3: Model to be used by someone else with no access to your code: In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

This way is still not bulletproof, and since PyTorch is still undergoing a lot of changes, I wouldn't recommend it.


If you want to save the model and resume training later:

Single GPU: Save:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Multiple GPU: Save

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Load:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

# Don't wrap the model in DataParallel before loading the state dict, otherwise you will get an error

model = torch.nn.DataParallel(model)  # skip this line if you want to load on a single GPU
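The reason for saving model.module.state_dict() in the multi-GPU case is that the DataParallel wrapper prefixes every key with "module.". A minimal sketch of the difference follows; it assumes that constructing the wrapper also works on a machine without a GPU, which should be checked in your environment:

```python
import torch

model = torch.nn.Linear(5, 2)
parallel_model = torch.nn.DataParallel(model)

# The wrapper stores the original model as `.module`, so its keys gain a prefix.
wrapped_keys = list(parallel_model.state_dict().keys())
plain_keys = list(parallel_model.module.state_dict().keys())

print(wrapped_keys)
print(plain_keys)
```

Saving the unprefixed keys keeps the checkpoint loadable into a plain, unwrapped model.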
