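LSTM (Long Short-Term Memory) networks: I've heard that using one to recognize the handwritten-digit dataset is a bit of overkill, but I've never had much hands-on practice with speech, sequence models and the like, so I'll play with them more when I have time.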
After training for one epoch, the accuracy was 97%.
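The hyperparameters in the code (input_size = 28, sequence_length = 28) point to the common way of feeding MNIST to an LSTM: each 28x28 image is read as a sequence of 28 time steps, one 28-pixel row per step, and the hidden state after the last row is classified. The rest of the script is not shown here, so the following is only a minimal sketch under that assumption, not necessarily the author's model; the class name `RowLSTM` is hypothetical, and a random tensor stands in for a real MNIST batch.

```python
import torch
import torch.nn as nn

# Hypothetical model (not from the original post): a stacked LSTM that reads
# each 28x28 image as 28 time steps of 28-pixel rows.
class RowLSTM(nn.Module):
    def __init__(self, input_size=28, sequence_length=28,
                 hidden_size=128, num_layers=2, num_classes=10):
        super().__init__()
        self.input_size = input_size
        self.sequence_length = sequence_length
        # batch_first=True: inputs are shaped (batch, seq_len, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # ToTensor() yields (batch, 1, 28, 28); drop the channel dim -> (batch, 28, 28)
        x = x.reshape(x.size(0), self.sequence_length, self.input_size)
        # Hidden and cell states default to zeros when not passed explicitly
        out, _ = self.lstm(x)            # out: (batch, seq_len, hidden_size)
        # Classify from the output at the last time step (after the last row)
        return self.fc(out[:, -1, :])    # (batch, num_classes)

model = RowLSTM()
fake_batch = torch.rand(100, 1, 28, 28)  # stand-in for one batch (batch_size = 100)
logits = model(fake_batch)
print(logits.shape)  # torch.Size([100, 10])
```

Taking only the last time step's output is one design choice; since the LSTM has seen all 28 rows by then, its hidden state summarizes the whole image.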
Full code:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
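# Variable is deprecated since PyTorch 0.4 (tensors track gradients directly),
# but importing it still works
from torch.autograd import Variable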
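input_size = 28        # pixels per image row = features fed to the LSTM at each time step
sequence_length = 28   # rows per image = number of time steps
hidden_size = 128      # size of the LSTM hidden state
num_layers = 2         # number of stacked LSTM layers
num_classes = 10       # digits 0-9
batch_size = 100
num_epochs = 1
learning_rate = 0.01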
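# download=False expects the MNIST files to already exist under ./data;
# set download=True on the first run to fetch them
train_datasets = dsets.MNIST(root='./data',
                             download=False,
                             train=True,
                             transform=transforms.ToTensor())
test_datasets = dsets.MNIST(root='./data',
                            download=False,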