数据集与先前下载的数据集一致,所以此次就不再下载,但仍然呈现了下载的代码
代码如下,大部分都加了详细的注释
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
EPOCH=1 #训练整批数据的次数
BATCH_SIZE = 64#批训练的数据个数
TIME_STEP=28#考虑多少时间点的数据,
INPUT_SIZE=28#每个时间点给RNN多少个数据点
LR = 0.01 # 学习率
DOWNLOAD_MNIST = False
train_data=torchvision.datasets.MNIST(#下载数据的代码
root='./mnist',
train=True,
transform=torchvision.transforms.ToTensor(), #(网上数据改为tensor),0-1之间,并复制到train_data中
download=DOWNLOAD_MNIST#没有下载就=true,下载了就用false
)
train_loader=Data.DataLoader(dataset=train_