构建一个简单的LSTM网络对mnist数据集进行识别
class RNN_Net(nn.Module): #定义一个神经网络,作为识别分类
def __init__(self,in_dim,hidden_dim,n_layer,n_class=10):
super(RNN_Net, self).__init__() #初始化模型
self.n_layer=n_layer #构造第一层网络
self.hidden_dim=hidden_dim#构造第二层网络
self.lstm=nn.LSTM(in_dim,hidden_dim,n_layer,batch_first=True)
self.classifier=nn.Linear(hidden_dim,n_class)#构造第三层网络,即输出层
def forward(self,x): #前向传播过程
out,_=self.lstm(x)
out=out[:,-1,:]
out=self.classifier(out)
return out