pytorchgpu测试_pytorch学习(十)—训练并测试CNN网络

前言

学习pytorch已经一周了,pytorch官网的示例代码基本上都敲了一遍,关于tensor的使用,数据集,网络定义等。和之前学习caffe痛苦的经历相比,pytorch对常用的操作都进行了封装,只要安装流程做即可。在之前的学习基础上,本章节内容将在CIFAR10数据集上训练一个简单的CNN网络。

目的

基于CIFAR-10数据集,训练一个简单CNN网络。

保存训练好的模型,测试

使用GPU训练

开发/实验环境

Ubuntu 18.04

pytorch 1.0

Anaconda3, python3.6

pycharm

CIFAR数据集

The CIFAR-10 and CIFAR-100 are labeled subsets of the 80 million tiny images dataset. They were collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.

CIFAR数据集可分为CIFAR10, CIFAR100。 CIFAR-10是指包含10个种类, CIFAR-100包含100个种类。

CIFAR-10

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

特点:

32x32 彩色图像

10个类别

总共60000张图像

50000张训练样本 + 10000张测试样本

每个类别有6000张图像, 10 x 6000 = 60000

10个类别:

airplane

automobile

bird

cat

deer

dog

frog

horse

ship

truck

image.png

实验过程

准备数据集

这一步骤在pytorch中非常方便,pytorch已经为我们准备好了常见的数据集合,只需要导入即可。

数据集在torchvision.dataset 包里面

import torch

import torchvision

import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

import matplotlib.pyplot as plt

import numpy as np

torchvision.dataset.CFAIR10 是一个类, 通过实例化该类的一个对象,就可以操作数据集。

参数:

root -----数据集下载后保存的路径

train-----训练or测试

download----是否需要自动下载

transform----对图像进行变换, 一般需要对原始图像进行ToTensor(), Normalize()变换

之后,使用DataLoader类对数据集进行包装,目的是为了方便读取和使用,比如可以min_batch读取, 采用多线程。

# --------------------准备数据集------------------

# Dataset, DataLoader

transform = transforms.Compose(

[transforms.ToTensor(),

transforms.Normalize((0.5, 0.5, 0.5), std =(0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,

download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data',train=False,

transform=transform, download=True)

trainloader = DataLoader(dataset=trainset, batch_size=4, shuffle=True, num_workers=4)

testloader = DataLoader(dataset=testset, batch_size=4, shuffle=True, num_workers=4)

#

dataiter = iter(trainloader)

images, labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))

# print labels

print(' '

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值