网络模型的保存于与读取
方法1:
1.1 如何保存网络模型
首先,创建一个py文件,model_save.py
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
torch.save(vgg16,"vgg16_model1_pth")
运行结束后我们会在我们左侧的文件出现vgg16_model1_pth这个文件
用这种方法保存,不仅保存了网络模型,也保存了网络模型中的相关参数

1.2 如何读取网络模型
新建一个py文件,model_load.py
import torch
model = torch.load("vgg16_model1_pth")
print(model)
输出:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
方法二
2.1:如何保存网络模型
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
#torch.save(vgg16,"vgg16_model1_pth")
torch.save(vgg16.state_dict(),"vgg16_model2_pth")
也会在左侧形成一个vgg16_model2_pth文件
只保存了模型的参数,占用空间更小,官方推荐方式
2.2:如何读取网络模型
读取方式与方法一 一样,但是输出为字典类型的数据
import torch
# 方式2-> 保存方式2,加载模型
model = torch.load("vgg16_model2_pth") # 加载出来的是字典类型的数据
print(model)
F:\Anaconda3\envs\pytorch\python.exe D:/Python/learn_torch/model_load.py
OrderedDict([('features.0.weight', tensor([[[[ 3.9726e-02, -4.0263e-02, 5.2152e-02],
[ 3.5984e-02, -4.6239e-02, -2.4924e-02],
[-9.6867e-03, 1.2961e-02, -4.5731e-02]],
[[ 1.9925e-03, 3.6464e-02, 5.6411e-02],
[-9.0956e-02, -3.6801e-02, -7.3917e-02],
[ 3.6363e-02, -4.5585e-02, -8.2003e-03]],
[[-1.1151e-01, -2.4694e-02, -3.4446e-02],
[-5.4018e-02, 7.9030e-02, 1.1468e-01],
[ 6.1839e-02, -8.7451e-02, 2.8596e-03]]],
[[[-6.4775e-02, 5.2936e-03, -1.8106e-02],
[-4.0254e-02, -8.5685e-02, -7.8011e-02],
[ 1.1739e-02, -7.9629e-02, 6.6174e-02]],
[[-1.1657e-01, 3.5422e-02, 6.2663e-02],
[ 3.0534e-02, 6.9120e-03, 3.3340e-03],
[-1.5356e-01, 7.2058e-02, 4.7606e-02]],
[[-1.2942e-01, -3.5475e-02, 9.7374e-02],
[-1.3898e-02, -2.5312e-02, 6.3060e-02],
[ 5.4231e-04, 1.4181e-02, 8.3530e-02]]],
[[[-1.5726e-03, 6.0129e-02, -2.5256e-02],
[-8.2932e-02, 9.2577e-02, 1.8457e-02],
[-5.7204e-02, -5.2296e-02, 8.6386e-02]],
[[-3.1392e-02, 1.2295e-01, -6.2096e-03],
[-1.6034e-02, 3.0497e-03, 5.9402e-02],
[-7.5480e-02, -6.9659e-02, -1.2263e-02]],
[[ 6.5706e-05, -4.6442e-02, 6.1466e-02],
[ 3.6150e-02, 3.6947e-02, -9.4802e-02],
[ 7.0997e-02, 1.2181e-02, 3.3660e-03]]],
.....................................
....................................
..................................
从上述输出结果中得到的结果是字典类型,其中参数的值也一起输出来了,如果想要查看具体的网络结构,需要这样
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_model2_pth")) # 输出完整的模型结构,与第一种方式输出的模型结构相同
print(vgg16)

1997

被折叠的 条评论
为什么被折叠?



