一、保存和读取参数
1、当训练完后,把当前的参数保存下来
import torch
torch.save(net.state_dict(), path)
保存参数只需用到torch.save(),其中net为自定义的模型名称,其子参数state_dict()为模型的参数,path为保存的路径加名称,其后缀为 pt 或 pth ,如: ‘pth/net_parameters.pth’。
2、加载参数
import torch
net.load_state_dict(torch.load(path))
二、保存和读取模型
保存和读取模型是把模型的网络架构以及其参数都进行保存和读取
import torch
torch.save(net, path) # 保存模型
net_ = torch.load(pth) # 读取模型
同样地, net为自定义的模型, 而net_为新加载的模型,path为路径和保存模型的名称,后缀为 pt 或 pth。