一、网络模型的使用和修改
# Append a 1000 -> 10 linear layer to the classifier head (VGG.forward calls
# self.classifier, so a layer registered there is actually used).
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
import torchvision
from torch import nn
from torchvision import models

# Example 1: load torchvision's VGG16 and adapt its head to a 10-class dataset.
# The ImageNet dataset variant is left disabled (kept only for reference):
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())

# Same VGG16 architecture twice: randomly initialised vs. pretrained weights.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# CIFAR10 (downloaded to ./data) is the target dataset; it has 10 classes, while
# the pretrained head produces 1000 outputs, hence the extra 1000 -> 10 layer below.
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

# Register the new layer on the `classifier` Sequential rather than on the model
# itself: VGG.forward runs features -> avgpool -> flatten -> classifier, so a
# module attached via vgg16_true.add_module(...) would be printed but never called.
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
二、网络模型的保存和加载
(1)模型的保存
- 保存方式1:同时保存模型结构和模型参数(整个模型对象)
# Save method 1: pickle the whole module (architecture + parameters).
torch.save(vgg16_false, "vgg16_method1.pth")
- 保存方式2:只保存模型参数(state_dict),这是 PyTorch 官方推荐的方式
# Save method 2: save only the parameters (state_dict).
torch.save(vgg16_false.state_dict(), "vgg16_method2.pth")
import torch
import torchvision
from torch import nn
from torchvision import models

# Example 2: save a model to disk in two different ways.
# The ImageNet dataset variant is left disabled (kept only for reference):
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())

# Same VGG16 architecture twice: randomly initialised vs. pretrained weights.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=models.VGG16_Weights.DEFAULT)

# Save method 1: pickle the whole module (architecture + parameters).
# Loading this file later requires the VGG class (torchvision) to be importable.
torch.save(vgg16_false, "vgg16_method1.pth")
# Save method 2: save only the parameters (state_dict); the same architecture
# must be rebuilt in code before these parameters can be loaded into it.
torch.save(vgg16_false.state_dict(), "vgg16_method2.pth")

# CIFAR10 (downloaded to ./data) has 10 classes; the pretrained head has 1000 outputs.
train_data = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)

# Register the new layer on the `classifier` Sequential rather than on the model
# itself: VGG.forward runs features -> avgpool -> flatten -> classifier, so a
# module attached via vgg16_true.add_module(...) would be printed but never called.
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)
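运行后,当前目录下会生成 vgg16_method1.pth 和 vgg16_method2.pth 两个文件,之后就可以用它们来加载模型。注意:方式1保存的是整个模型对象,加载时需要能导入对应的模型类(这里是 torchvision 中的 VGG);方式2只保存了参数字典(state_dict),加载时需要先构建相同结构的模型,再把参数载入进去。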
(2)模型的加载
import torch
import torchvision

# Load method 1 -> for files written with save method 1 (a whole pickled module).
# torchvision must be imported so the pickled VGG class can be resolved.
# weights_only=False is needed because PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses to unpickle a full nn.Module (the keyword
# itself requires PyTorch >= 1.13).
# SECURITY: this performs a full pickle load -- only use it on trusted files.
model = torch.load("vgg16_method1.pth", weights_only=False)
print("vgg16_false1:", model)

# Load method 2 -> for files written with save method 2 (state_dict only):
# rebuild the same architecture first, then copy the saved parameters into it.
# A state_dict contains only tensors, so the safe weights_only=True mode suffices.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_false.load_state_dict(torch.load("vgg16_method2.pth", weights_only=True))
print("vgg16_false2:", vgg16_false)
运行后打印出的模型结构如下(节选):