import torch
from torch import nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# torch.manual_seed(1) # reproducible
# 定义一些参数
EPOCH = 1 # 训练次数
BATCH_SIZE = 64 # 一次训练的数据量,可以理解为有多少条句子
TIME_STEP = 28 # 可以理解为一个句子的序列长度
INPUT_SIZE = 28 # 可以理解为每个词向量的维度,也就是输入维度,假如是3,那就是3
LR = 0.01 # learning rate
DOWNLOAD_MNIST = False # set to True if haven't download the data
# 定义数据集
train_data = dsets.MNIST(
root='./mnist/',
train=True, # this is training data
transform=transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
down