2018-01-07 Save and reload a PyTorch model
Saving and loading a PyTorch model took me a couple of searches on the internet, so let's write it down. To save a model:
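```python
import torch

model = create_model()  # build the model structure (placeholder function)
train_model(model)  # something which trains the model (placeholder function)
torch.save(model.state_dict(), filename)  # save only the learned parameters
```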
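The snippet above leaves `create_model`, `train_model` and `filename` undefined. Here is a minimal runnable sketch that fills them in with hypothetical stand-ins: a small `nn.Sequential` network, a few SGD steps on random data, and a file named `model.pt` in the temporary directory. None of these come from the original post; they only make the save step concrete and show what the state dictionary contains.

```python
import os
import tempfile

import torch
from torch import nn


def create_model():
    """Hypothetical model factory: builds the network structure only."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def train_model(model, steps=10):
    """Hypothetical training loop on random data, standing in for real training."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    inputs, targets = torch.randn(32, 4), torch.randn(32, 2)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()


model = create_model()
train_model(model)

# Hypothetical file name; ".pt" is a common extension for PyTorch files.
filename = os.path.join(tempfile.gettempdir(), "model.pt")
# Only the learned tensors (weights and biases) are written, not the structure.
torch.save(model.state_dict(), filename)
print(list(model.state_dict().keys()))
# ['0.weight', '0.bias', '2.weight', '2.bias']
```

The keys are the layer positions inside the `nn.Sequential`; the `ReLU` at position 1 has no parameters, so it does not appear.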
To load it:
```python
import torch

model = create_model()  # rebuild the same structure as when saving
model.load_state_dict(torch.load(filename))  # copy the saved parameters into it
```
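A quick way to check that loading worked is a round trip: save a model, rebuild it with the same factory, load the state dictionary, and compare outputs on the same input. The sketch below assumes the same hypothetical `create_model` as before and a temporary file `model_roundtrip.pt`. It also passes `map_location="cpu"`, an existing `torch.load` argument that lets a checkpoint saved on a GPU be loaded on a CPU-only machine, and calls `eval()` before inference, which matters for layers such as dropout or batch normalization.

```python
import os
import tempfile

import torch
from torch import nn


def create_model():
    """Hypothetical factory; it must build exactly the same structure as at save time."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


filename = os.path.join(tempfile.gettempdir(), "model_roundtrip.pt")

# Save a freshly initialised model so this example is self-contained.
original = create_model()
torch.save(original.state_dict(), filename)

# Load: rebuild the structure, then copy the saved tensors into it.
restored = create_model()
state = torch.load(filename, map_location="cpu")  # remap tensors to CPU memory
restored.load_state_dict(state)
restored.eval()  # inference mode (affects dropout / batch norm, none here)

x = torch.randn(5, 4)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)  # True
```

If `create_model` builds a different structure than the one that was saved, `load_state_dict` raises an error listing the missing or unexpected keys.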
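These two snippets assume that a function exists which creates the model structure: the state dictionary only holds the parameter and buffer tensors, not the architecture. This comes from the fact that PyTorch models can contain arbitrary custom Python code (functions, lambdas, custom modules), which pickle cannot always serialize.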
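To illustrate the pickle limitation, here is a small sketch using a hypothetical `Scale` module that keeps a lambda as an attribute. Saving the whole module with `torch.save(model)` has to pickle every attribute and fails on the lambda, whereas the state dictionary only contains tensors and saves without trouble. The exact exception type depends on the Python version, hence the tuple in the `except` clause.

```python
import io
import pickle

import torch
from torch import nn


class Scale(nn.Module):
    """Hypothetical module that stores a custom function (a lambda)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.activation = lambda x: x * 2  # custom function: pickle cannot serialize it

    def forward(self, x):
        return self.activation(self.linear(x))


model = Scale()

# Saving the whole module pickles its attributes, including the lambda.
try:
    torch.save(model, io.BytesIO())
    whole_model_saved = True
except (pickle.PicklingError, AttributeError, TypeError):
    whole_model_saved = False

# The state_dict only holds tensors, so it serializes fine.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

print(whole_model_saved)  # False
print(sorted(model.state_dict().keys()))  # ['linear.bias', 'linear.weight']
```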