前言
学习pytorch已经一周了,pytorch官网的示例代码基本上都敲了一遍,关于tensor的使用,数据集,网络定义等。和之前学习caffe痛苦的经历相比,pytorch对常用的操作都进行了封装,只要安装流程做即可。在之前的学习基础上,本章节内容将在CIFAR10数据集上训练一个简单的CNN网络。
目的
基于CIFAR-10数据集,训练一个简单CNN网络。
保存训练好的模型,测试
使用GPU训练
开发/实验环境
Ubuntu 18.04
pytorch 1.0
Anaconda3, python3.6
pycharm
CIFAR数据集
The CIFAR-10 and CIFAR-100 are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
CIFAR数据集可分为CIFAR10, CIFAR100。 CIFAR-10是指包含10个种类, CIFAR-100包含100个种类。
CIFAR-10
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
特点:
32x32 彩色图像
10个类别
总共60000张图像
50000张训练样本 + 10000张测试样本
每个类别有6000张图像, 10 x 6000 = 60000
10个类别:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
image.png
实验过程
准备数据集
这一步骤在pytorch中非常方便,pytorch已经为我们准备好了常见的数据集合,只需要导入即可。
数据集在torchvision.dataset 包里面
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
torchvision.dataset.CFAIR10 是一个类, 通过实例化该类的一个对象,就可以操作数据集。
参数:
root -----数据集下载后保存的路径
train-----训练or测试
download----是否需要自动下载
transform----对图像进行变换, 一般需要对原始图像进行ToTensor(), Normalize()变换
之后,使用DataLoader类对数据集进行包装,目的是为了方便读取和使用,比如可以min_batch读取, 采用多线程。
# --------------------准备数据集------------------
# Dataset, DataLoader
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), std =(0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,
transform=transform, download=True)
trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, num_workers=4)
testloader = DataLoader(dataset=testset, batch_size=4, shuffle=True, num_workers=4)
#
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '