方法一:使用torchvision.transforms中的transforms和torch.utils.data中的DataLoader
简介:这段代码是MNIST手写体识别中的部分代码。
#此篇代码为MNIST手写体识别中的代码
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
#定义一些超参数,只列举train_batch和test_batch
train_batch_size = 64
test_batch_size = 128
#下载数据并对数据进行预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])
#下载数据
train_dataset = mnist.MNIST('./data',train=True,transform = transform,download=True)
test_dataset = mnist.MNIST('./data',train=False,transform = transform)
#创建DataLoader
train_loader = DataLoader(train_dataset,batch_size = train_batch_size,shuffle=True)
test_loader = DataLoader(test_dataset,batch_size = test_batch_size,shuffle=True)
参数解释:
- transforms.N