PyTorch入门(二)LeNet网络CIFAR10分类识别
利用PyTorch构建动态计算图的一般步骤
Load and nomoralize the training and test datasets
Define a Convolutional Neural Network
Difne a loss function
Train the network on training data
Test the netword on test data
LeNet进行CIFAR10分类Python代码
##Load and Normalizing CIFAR10
################################################################
import torch
import torchvision
import torchvision.transforms as transforms
IfDownLoad = False
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data',
train=True,
download=IfDownLoad,
transform=transform)
# 该参数是指在进行数据集加载时,启用的线程数目。截止当前2018年5月9日11:15:52,如官方未解决该BUG,则可以通过修改num_works参数为 0 ,只启用一个主进程加载数据集,避免在windows使用多线程即可。
trainloader = torch.utils.data.DataLoader(trainset,
batch_size=4,
shuffle=True,