# 以Fashion-MNIST数据集为例，演示如何用PyTorch创建、训练和测试卷积神经网络LeNet-5
# 导入依赖包
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import torchvision.transforms as transforms
import sys
# 加载数据集
# Fashion-MNIST training and test splits. Each sample is passed through
# transforms.ToTensor(); download=True fetches the files into `root` when
# they are not already present there.
mnist_train = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
mnist_test = torchvision.datasets.FashionMNIST(
    root='~/Datasets/FashionMNIST',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

batch_size = 128

# No worker subprocesses on Windows, four everywhere else.
# NOTE(review): presumably because DataLoader workers on Windows are spawned
# and would re-import this unguarded module-level script — confirm.
num_workers = 0 if sys.platform.startswith('win') else 4

# Training loader: reshuffles the samples each epoch and yields mini-batches
# of `batch_size` samples.
train_iter = Data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_iter = Data.DataLoader(mnist_test,batch_size=batch_size,shuffle=Fa