对于LSTM,我们要处理的数据是一个序列数据,对于图片而言,我们如何将其转换成序列数据呢?图片的大小是28x28,所以我们可以将其看成长度为28的序列,序列中的每个数据的维度是28,这样我们就可以将其变成一个序列数据了。
model
class Rnn(nn.Module):
def __init__(self, in_dim, hidden_dim, n_layer, n_class):
super(Rnn, self).__init__()
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer,
batch_first=True)
self.classifier = nn.Linear(hidden_dim, n_class)
def forward(self, x):
# h0 = Variable(torch.z