pytorch模型保存的两种方法
假设实例化的模型为model,类为Model,Path表示模型的路径
- 如果模型中的参数随着程序的运行而变化可使用
# 保存
torch.save(model ,Path)
# 加载
# 这里model不用初始化为model = Model(),但是一定要首先引入Model类
import Model # 如果没有这个类的化,下面的语句会报错!
model = torch.load(Path)
- 如果模型中的参数数量随着程序运行不发生变化,则可使用:
# 保存
torch.save(model.state_dict(),Path)
# 加载
model = Model()
# 这里注意加载的模型参数数量要与Model()初始化的模型参数数量要一置,不然会报错
model.load_dict_state(torch.load(Path))
下面我举一个参数会发生变化的模型,其中1,2是相互配套的代码
import torch
import torch.nn as nn
class A(nn.Module):
def __init__(self):
nn.Module.__init__(self)
self.register_buffer("aaa", torch.tensor([]))
def forward(self,x):
self.aaa.data = torch.cat((self.aaa,torch.tensor([x])),0)
b = A() # ----1
b.load_state_dict(torch.load('./fucc.pt')) # ----1
# b = torch.load('./fucc.pth') # ----2
for i in range (20):
b(i)
torch.save(b.state_dict(),'fucc.pt') # ----1
# torch.save(b,'fucc.pth') # ----2