本文描述pytorch保存和加载模型的两种方法。参考:https://pytorch.org/docs/stable/notes/serialization.html
1 仅保存和加载模型参数
只保存模型参数,所占空间较少,但是加载时必须先定义好网络模型,并且加载时的网络模型和保存时的必须一模一样。
保存:
torch.save(the_model.state_dict(), PATH)
加载:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
2 保存和加载整个模型
保存整个模型,包括模型参数和网络结构,加载时不用先定义网络结构,直接加载就可以。
保存:
torch.save(the_model, PATH)
加载:
the_model = torch.load(PATH)