一、使用VGG模型进行猫狗大战
内容包含: 1.主要步骤解读 2. 代码 及解释 3.运行截图 4.知识点补充 5.遇到的问题及解决方法
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import models,transforms,datasets
import time
import json
# 判断是否存在GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())
Using gpu: False
1. 下载数据
Jeremy Howard 提供了数据的下载,链接为:http://files.fast.ai/data/dogscats.zip
在他整理的数据集中,猫和狗的图片放在单独的文件夹中, 同时还提供了一个Validation数据。如果没有GPU设备,请减少用做训练的图像数据量即可。
因为这个代码需要在colab上跑,速度会相对较慢。因此,我们重新整理了数据,制作了新的数据集,训练集包含1800张图(猫的图片900张,狗的图片900张),测试集包含2000张图。下载地址为:http://fenggao-image.stor.sinaapp.com/dogscats.zip
import wget
#引入基本的库
url='http://fenggao-image.stor.sinaapp.com/dogscats.zip'
file_name=wget.download(url)
print(file_name)
import zipfile
import os
src_path='dogscats.zip'
target_path='dogscats'
if(not os.path.isdir(target_path+"/Images")):
z = zipfile.ZipFile(src_path, 'r')
z.extractall(path=target_path)
z.close()
else:
print("文件已解压")
``
问题及解决方法
注意:这里要变化一下,如果直接用! wget +地址来下载的话,报错
修改成这个样子
OK
2. 数据处理
datasets 是 torchvision 中的一个包,可以用做加载图像数据。它可以以多线程(multi-thread)的形式从硬盘中读取数据,使用 mini-batch 的形式,在网络训练中向 GPU 输送。在使用CNN处理图像时,需要进行预处理。图片将被整理成 224 × 224 × 3 224\times 224 \times 3 224×224×3 的大小,同时还将进行归一化处理。
torchvision 支持对输入数据进行一些复杂的预处理/变换 (normalization, cropping, flipping, jittering 等)。
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
vgg_format = transforms.Compose([
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
data_dir = './dogscats'
dsets = {
x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
for x in ['train', 'valid']}
dset_sizes = {
x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes
通过下面代码可以查看 dsets 的一些属性解析:
print(dsets[‘train’].classes)#分的文件夹的名字下的类别
print(dsets[‘train’].class_to_idx)#按顺序为这些类别定义索引为0,1,2,…
print(dsets[‘train’].imgs[:5])#返回从文件下中得到的图片的路径以及其类别
print('dset_sizes: ', dset_sizes)#返回该文件夹的内的图片数量
# 通过下面代码可以查看 dsets 的一些属性
print(dsets['train'].classes)
print(dsets['train'].class_to_idx)
print(dsets['train'].imgs[:5])
print('dset_sizes: ', dset_sizes)
运行结果:
使用DataLoader解析:
torch.utils.data.DataLoader主要是对数据进行batch的划分,
batch_size批训练集的大小
num_workers:多线程来读数据
shuffle:是否打乱数据
使用DataLoader的好处是,可以快速的迭代数据。
loader_train = torch.utils.data.DataLoader(dsets['train'], batch_size