模型的保存(torch.save)
方式1(模型结构+模型参数)
参数:保存位置
# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
方式2(模型参数)
# 保存方式2 模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
模型的加载(torch.load)
对应保存方式1
参数:模型路径
# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")
对应保存方式2
vgg16.load_state_dict("vgg16_method2.pth")
输出为字典形式。若要回复网络,采用以下形式:
model2 = torch.load("vgg16_method2.pth") #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)
方式1存储,加载时需注意事项
新建自己的网络:
class test(nn.Module):
def __init__(self):
super(lh, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
保存自己的网络:
Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")
加载自己的网络:
model3 = torch.load("Test_method1.pth")
会报错!!!!!!
解决办法(需要注意):
将定义的网络复制到加载的python文件中:
class test(nn.Module):
def __init__(self):
super(test, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
model3 = torch.load("Test_method1.pth")