1. Pytorch 中的模型
包括:语音识别、文本处理、图像识别等
2. torchvision
The following classification models are available, with or without pre-trained weights:
下述可用的分类模型:
1. VGG 的使用
torchvision.models.vgg16是一个在PyTorch中实现的VGG-16模型,用于图像分类任务。它是在ImageNet数据集上预训练过的,并通过将全连接层替换为适合特定任务的新全连接层来进行微调。
参数介绍如下:
- pretrained: 一个bool值,表示是否使用在ImageNet上预训练的权重。默认为False。
- progress: 一个bool值,表示在下载预训练权重时是否显示下载进度。默认为True。
当实例化VGG-16模型时,可以设置这些参数来指定是否使用预训练的权重以及是否显示下载进度。
1. vgg模型参数的下载
"""
pretrained=False
progress=True
"""
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 查看 vgg16 的模型结构
print(vgg16_true)
2. vgg模型的添加与修改
# 在 该处添加一个 线性层
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)
# 修改 原有的 模型结构
print(vgg16_false)
vgg1 = vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)
3. vgg模型的保存与加载
方式一:
注意:
1. 保存模型的文件后缀名一般为:.pth
2. 该种 模型 可以直接使用与加载
# 保存方式一
torch.save(vgg16_true,"vgg_saveModel1.pth")
# 加载方式一
model1 = torch.load("vgg_saveModel1.pth")
print(model1)
方式二:
保存方式二 : 将 vgg 中的参数,保存为 字典形式。 ---》 官方推荐,所占空间更小
最后通过 vgg16_noParam.load_state_dict() 将参数加载至模型之中!!
# 保存方式二 : 将 vgg 中的参数,保存为 字典形式。 ---》 官方推荐,所占空间更小
torch.save(vgg16_true.state_dict(),"vgg_saveModel2.pth")
# 加载方式二
model_params = torch.load("vgg_saveModel2.pth")
print(model_params)
vgg16_noParam = torchvision.models.vgg16(pretrained=False)
vgg = vgg16_noParam.load_state_dict(model_params)
print(vgg)
注意 : 方式 二 的使用,更加的适用!!!!