2018-01-07 Save and reload a PyTorch model

Saving and loading a PyTorch model took me a couple of internet searches, so let's write it down. To save a model:

import torch
model = create_model()  # create model
train_model(model)      # something which trains the model
torch.save(model.state_dict(), filename)

To load it:

import torch
model = create_model()  # create model
model.load_state_dict(torch.load(filename))

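These two snippets assume that there is a function which creates the model structure. This is needed because state_dict() holds only the model's parameters and buffers, not the code that defines the architecture. A PyTorch model is an ordinary Python class with custom functions, and pickle does not serialize the code of those functions, so the structure has to be rebuilt before the saved weights can be loaded into it.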
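As a rough end-to-end sketch of the two snippets above, the following round trip saves a state dict to a temporary file and loads it into a freshly created model. The create_model() body here is a hypothetical stand-in (a tiny nn.Sequential), training is skipped, and the file name model.pt is an arbitrary choice for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def create_model():
    # Hypothetical stand-in for the post's create_model(): a small fixed architecture.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = create_model()  # randomly initialised weights; training is omitted here
model.eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    filename = os.path.join(tmp_dir, "model.pt")  # arbitrary file name

    # Save only the parameters and buffers, not the model object itself.
    torch.save(model.state_dict(), filename)

    # Rebuild the same structure, then copy the saved weights into it.
    restored = create_model()
    restored.load_state_dict(torch.load(filename))

restored.eval()

# Both models should now give identical outputs for the same input.
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)
```

If create_model() builds a different structure at load time (for example, different layer sizes), load_state_dict() raises an error instead of loading silently, which is a useful sanity check.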