VGG模型实现“猫狗大战挑战赛”
1 大赛简介
2013年,Kaggle举办过一个很受欢迎的猫狗识别竞赛(Dogs vs. Cats),比赛内容是识别图像中的是猫还是狗。
其中AI研习社平台的猫狗大战赛可以每天提交测试结果,挑战赛内容以及数据集都在此。
本博客是基于 github 上 summitgao 老师所编写的建立的VGG模型基础上进行优化改进,并对关键部分进行解读。
2 运行环境及准备
谷歌的 Colab 平台提供免费的云端 GPU ,本篇博客的代码均在Colab上运行。
2.1 数据下载
- 如下Shell获取数据集并解压在Google云端
!wget http://fenggao-image.stor.sinaapp.com/dogscats.zip
!unzip dogscats.zip
!wget https://static.leiphone.com/cat_dog.rar
!unrar x /content/cat_dog.rar
- zip文件为 summitgao 老师制作的新的数据集,训练集包含1800张图。
- rar文件为大赛原数据集,训练集包含20000张图。
cat_dog数据集目录:
- test:最后通过训练好的模型来识别的测试图片。
- train:用来训练模型的图片。
- val:文件夹下的图片有确定的标签,用来测试模型训练效果。
- dogscats 数据集同理
由于代码需要在 Colab 上运行,速度会相对较慢。因此,本次使用 dogscats 数据集进行训练。
2.2 引入Pytorch包及启用GPU模式
- 启用GPU:
在 Colab 编辑器界面:修改—笔记本设置—硬件加速器—GPU
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import models,transforms,datasets
import time
import json
# 判断是否存在GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())
2.3 数据预处理
- 完成数据集的准备后,需要对数据进行一些预处理:
- 图片将被整理成 224 × 224 × 3 的大小,同时还将进行归一化处理。
- 其他的一些对数据的复杂的预处理/变换 (normalization, cropping, flipping, jittering 等)可以参照 torchvision.tranforms 的官方文档说明。
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
vgg_format = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
data_dir = '/content/dogscats'
dsets = {
x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
for x in ['train', 'valid']}
dset_sizes = {
x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets