代码如下:
# 1.定义超参数
BATCH_SIZE = 16
# 2.构建pipeline,对图片做一些变换
pipeline = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,),(0.3081,))
])
# 3.加载数据集
from torch.utils.data import DataLoader
# 下载数据集
train_set = datasets.MNIST("data",train=True,download=True,transform=pipeline)
test_set = datasets.MNIST("data",train=False,download=True,transform=pipeline)
# 加载数据
train_loader = DataLoader(train_set,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(test_set,batch_size=BATCH_SIZE,shuffle=True)
这样就会自动下载数据集,下载的数据集是在你的代码所在的文件夹下面,一个叫data的文件夹,那里面就是下载的图片。
以上来源于从B站优秀的up主所讲解的视频中整理,B站链接为:
https://www.bilibili.com/video/BV1WT4y177SA?t=416