pytorch对CIFAR10进行分类
最近实验室的项目需要用到pytorch,pytorch也是目前最流行的深度学习框架之一,非常简单好用,这次就使用pytorch搭建一个简单的卷积神经网络用于CIFAR10的分类任务。
对于一个有监督的分类任务,大概的训练过程是:
加载数据,包括测试数据和训练数据,数据包括数据本身和标签
定义一个网络,其中包括卷积层,池化层,全连接层,激活层等等
将输入数据送入网络
得到网络的输出,根据预测和真实计算损失
将损失进行反向传播,得到网络各阶段的梯度
使用优化函数根据反向传播的损失梯度对网络各层的权重进行更新
在搭建网络之前首先确认你已经成功安装了pytorch环境–
下载数据集
对于计算机视觉,pytorch中有一个包专门用来加载和处理数据–torchvision
,想要加载CIFAR10的数据,非常的简单。
import pytorch
import torchvision
import torchvision. transforms as transforms
transform = transforms. Compose(
[ transforms. ToTensor( ) ,
transforms. Normalize( ( 0.5 , 0.5 , 0.5 ) , ( 0.5 , 0.5 , 0.5 ) ) ] )
trainset = torchvision. datasets. CIFAR10( root= './data' , train= True , download= True , transform= transform)
trainloader = torch. utils. data. DataLoader( trainset, batch_size= 4 , shuffle= True )
testset = torchvision. datasets. CIFAR10( root= './data' , train= False , download= True , transform= transform)
testloader = torch. utils. data. DataLoader( testset, batch_size= 4 , shuffle= False )
classes = ( 'plane' , 'car' , 'bird' , 'cat' , 'deer' , 'dog' , 'frog' , 'horse' , 'ship' , 'truck' )
运行上面的代码,等待时间取决于你的网络速度,你将会看到这样的结果:
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
下载好的数据集就保存在当前文件夹的data文件夹下,是一个压缩文件,我们可以对其中的部分图片进行可视化
import matplotlib. pyplot as plt
import numpy as np
def imshow ( img) :
img = img / 2 + 0.5
nping = img. numpy( )
plt. imshow( np. transpose( nping, ( 1 , 2 , 0 ) ) )
plt. show( )
dataiter = iter ( trainloader)
images, labels = dataiter. next ( )
imshow( torchvision. utils. make_grid( images) )
print ( ' '