【入门教程】使用预训练模型进行训练、预测(以VGG16为例)

【新手教程】使用预训练模型进行训练、预测(以VGG16为例)

本文参考[csdn博文]( Pytorch学习笔记(I)——预训练模型(一):加载与使用_lockonlxf的博客-CSDN博客_pytorch使用预训练模型),修改了一些小问题

本文环境:win10、torch>=1.6

本文所有相关代码:阿里云盘

1、基础知识

VGG16是一个简单的深度学习模型,可以实现图像的分类。PyTorch的库中有VGG16的模型构架,在torchvision.models中:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
	......
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

(C,W,H)格式输入,输入RGB图像,通过(features)和(avgpool)得到一个(512,7,7)的特征图,将特征图输入到分类器中,通过线性化等一系列操作输出一个维度为1000的特征向量,对应1000个类别,其值可以简单理解为对应各个类别的可能性,通过值大小来判断图像类别。

2、具体实现

1 模型的加载与修改

本项目例子是猫狗分类,即给一张图片判断是猫片还是狗片,对应只有2个类别,所以需要把VGG分类器的最后一层输出改为2,具体实现为:

model = orchvision.models.vgg16(pretrained=True) # 加载torch原本的vgg16模型,设置pretrained=True,即使用预训练模型
num_fc = model.classifier[6].in_features # 获取最后一层的输入维度
model.classifier[6] = torch.nn.Linear(num_fc, num_cls)# 修改最后一层的输出维度,即分类数
# 对于模型的每个权重,使其不进行反向传播,即固定参数
for param in model.parameters():
    param.requires_grad = False
# 将分类器的最后层输出维度换成了num_cls,这一层需要重新学习
for param in model.classifier[6].parameters():
    param.requires_grad = True

修改完之后可以直接print(model)查看模型结构:

VGG(
  ......
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=2, bias=True)
  )
)

可以看到分类器最后的out_features=2

2 模型训练

1 数据准备

本文使用的是torch自带的ImageFolder进行数据读取,需要注意的是:读取的文件夹必须在一个大的子文件下,按类别归好类。示例数据集整理如图:

cat、dog即为类别名称,训练集和测试集都需要保持一样的命名。读取数据代码如下:

def dataload(trainData, testData):
    # 训练数据
    train_data = torchvision.datasets.ImageFolder(trainData, transform=transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ]))
    train_loader = DataLoader
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值