CV入门--VGG16迁移学习(猫狗分类)实战

😁大家好,我是CuddleSabe,目前大四在读,深圳准入职算法工程师,研究主要方向为多模态(VQA、ImageCaptioning等),欢迎各位佬来讨论!
🍭我最近在有序地计划整理CV入门实战系列NLP入门实战系列。在这两个专栏中,我将会带领大家一步步进行经典网络算法的实现,欢迎各位读者(da lao)订阅🍀

迁移学习

迁移学习即在当前任务的数据量过小时,希望通过另一个数据集来预训练模型,希望通过这样学习到泛化性较强的特征,从而在当前目标任务上进行微调以获得不错的效果(一般做法为将特征提取部分进行参数冻结,只微调分类层)。
迁移学习在大体主要分为两种:即源域特征与目标域特征的特征域是否一致。
在这里插入图片描述
例子如上图所示。
本教程为入门教程,这里只探讨从ImageNet(1000类)到猫狗二分类的简单迁移学习,迁移学习后面会出一篇专门介绍。

VGG网络

VGG网络在是2014年ILSVRC竞赛的第二名,第一名是GoogLeNet。但是VGG模型在多个迁移学习任务中的表现要优于googLeNet。而且,从图像中提取CNN特征,VGG模型是首选算法。它的缺点在于,参数量有140M之多,需要更大的存储空间。但是这个模型很有研究价值。
VGG有如下图所示六种模型,这里我们使用的是最常见的VGG-D,即VGG16
在这里插入图片描述

思路介绍

在该教程中,我们希望通过在ImageNet数据集上训练好的模型来迁移到猫狗分类数据中,即我们需要进行以下步骤:

  • 下载VGG预训练模型
  • 修改最后分类层的参数(1000类改为2类)
  • 冻结前面的特征提取层(设置参数的requires_grad=False
  • 在猫狗数据集进行微调

代码编写

1.数据集划分

import pandas as pd
import os
from tqdm import tqdm
dir_path = './data/train'
dirs = os.listdir(dir_path)
tmp_path, tmp_label = [], []
for path in dirs:
    label = 1 if path.split('.')[0]=='dog' else 0
    tmp_label.append(label)
    tmp_path.append(os.path.join(dir_path, path))
csv_data = {'path': tmp_path, 'label':tmp_label}
data = pd.DataFrame(csv_data)
data.to_csv('./data/data.csv')
data = data.sample(frac=1)
train_data = data[:int(len(data)*0.85)]
val_data = data[int(len(data)*0.85):]
print('共划分训练集样本{}个,验证集样本{}个'.format(len(train_data), len(val_data)))

请添加图片描述

2.读取预训练模型

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
vgg16 = models.vgg16(pretrained=True)

更改Vgg16模型的最后一层

vgg16.classifier[-1] = nn.Linear(4096, 2, bias=True)
vgg16.classifier

请添加图片描述

将特征提取层的参数冻结,即我们只需微调最后分类层的参数

for param in vgg16.features.parameters():
    param.requires_grad = False

3. 数据集编写

class VggDataset(Dataset):
    def __init__(self, data):
        super(VggDataset, self).__init__()
        self.path = data['path']
        self.label = data['label']
        self.len = len(self.path)
        self.transform = transforms.Compose([
            transforms.Resize((224,224)),# 重置图像分辨率
            transforms.ToTensor(),# 转化为张量并归一化至[0-1]
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    def __getitem__(self, index):
        img = self.transform(Image.open(self.path.iloc[index]))
        return img, torch.LongTensor([self.label.iloc[index]]).squeeze()
    
    def __len__(self):
        return self.len
BATCH_SIZE = 64
EPOCHS = 10
lr = 0.001
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_set = VggDataset(train_data)
val_set = VggDataset(val_data)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

4. 训练

loss_stack = []
val_acc = []
optimizer = torch.optim.Adam(vgg16.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()
vgg16 = vgg16.to(device)

for epoch in tqdm(range(EPOCHS)):
    vgg16.train()
    all_loss = 0
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)
        preds = vgg16(data)
        loss = loss_fn(preds, labels)
        all_loss += loss.detach().cpu().numpy()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_stack.append(all_loss)
            
    
    vgg16.eval()
    count = 0
    data_num = 0
    for data, labels in val_loader:
        data, labels = data.to(device), labels.to(device)
        preds = vgg16(data)
        pred_labels = torch.argmax(preds, dim=1)
        for i, p in enumerate(pred_labels):
            if p == labels[i]:
                count += 1
        data_num += data.shape[0]
    val_acc.append(count / data_num)    

请添加图片描述

5. 查看效果

%config InlineBackend.figure_format = 'retina'
%matplotlib inline
import matplotlib.pyplot as plt

plt.subplot(1, 2, 1)
plt.plot(val_acc, c='red')
plt.subplot(1, 2, 2)
plt.plot(loss_stack, c='red')
plt.show()

请添加图片描述

def predict(img_path, model, device):
    transform = transforms.Compose([
            transforms.Resize((224,224)),# 重置图像分辨率
            transforms.ToTensor(),# 转化为张量并归一化至[0-1]
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    
    img = Image.open(img_path)
    img = transform(img).unsqueeze(0).to(device)
    model.eval()
    pred = model(img)
    label = torch.argmax(pred, dim=1)
    if label.data == 0:
        return 'cat'
    else:
        return 'dog'
img = Image.open('./data/test.jpg')
display(img)

在这里插入图片描述

predict('./data/test.jpg', vgg16, device)

请添加图片描述

预测成功

数据集及代码分享

链接: https://pan.baidu.com/s/19qRXeq-utdFY-qCdU4rflg 提取码: brqf

  • 7
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 14
    评论
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CuddleSabe

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值