Pytorch 中的 模型的使用

1. Pytorch 中的模型

 包括:语音识别、文本处理、图像识别等

 2. torchvision

The following classification models are available, with or without pre-trained weights:

下述可用的分类模型:

1. VGG 的使用

torchvision.models.vgg16是一个在PyTorch中实现的VGG-16模型,用于图像分类任务。它是在ImageNet数据集上预训练过的,并通过将全连接层替换为适合特定任务的新全连接层来进行微调。

参数介绍如下:

  • pretrained: 一个bool值,表示是否使用在ImageNet上预训练的权重。默认为False。
  • progress: 一个bool值,表示在下载预训练权重时是否显示下载进度。默认为True。

当实例化VGG-16模型时,可以设置这些参数来指定是否使用预训练的权重以及是否显示下载进度。

 1. vgg模型参数的下载

"""
    pretrained=False
    progress=True
"""
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

# 查看 vgg16 的模型结构
print(vgg16_true)

2. vgg模型的添加与修改

# 在 该处添加一个 线性层
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)

# 修改 原有的 模型结构
print(vgg16_false)
vgg1 = vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)

3. vgg模型的保存与加载

方式一:

注意:

        1. 保存模型的文件后缀名一般为:.pth

        2. 该种 模型 可以直接使用与加载

# 保存方式一
torch.save(vgg16_true,"vgg_saveModel1.pth")
# 加载方式一
model1 = torch.load("vgg_saveModel1.pth")
print(model1)

方式二:

 保存方式二  : 将 vgg 中的参数,保存为 字典形式。 ---》 官方推荐,所占空间更小

                最后通过 vgg16_noParam.load_state_dict() 将参数加载至模型之中!!

# 保存方式二  : 将 vgg 中的参数,保存为 字典形式。 ---》 官方推荐,所占空间更小
torch.save(vgg16_true.state_dict(),"vgg_saveModel2.pth")
# 加载方式二
model_params = torch.load("vgg_saveModel2.pth")
print(model_params)
vgg16_noParam = torchvision.models.vgg16(pretrained=False)
vgg = vgg16_noParam.load_state_dict(model_params)
print(vgg)

注意 : 方式 二 的使用,更加的适用!!!!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值