本篇博客主要介绍采用RNN做MNIST数据集分类。
示例代码:
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# Hyperparameters
EPOCH = 1  # presumably the number of training epochs; the training loop is not shown here
BATCH_SIZE = 64
TIME_STEP = 28  # rnn time step / image height
INPUT_SIZE = 28  # rnn input size / image width
LR = 0.01  # presumably the learning rate; the optimizer is not shown here
# Set to True if the MNIST data has not been downloaded yet.
DOWNLOAD_MNIST = False
# Backward-compatible alias for the original (misspelled) name, which the
# dataset construction code still refers to.
DOWNLOWD_MNIST = DOWNLOAD_MNIST
# Load the data
# Training set: the MNIST split selected by train=True, stored under ./mnist,
# with transforms.ToTensor() applied to each sample.
train_data = datasets.MNIST(
    root='./mnist',
    train=True,
    transform=transforms.ToTensor(),
    download=DOWNLOWD_MNIST,
)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BAT