上一节了解到了RNN和LSTM的基础知识,这节我们使用LSTM网络对mnist手写数字数据集进行处理。重点了解LSTM网络的搭建方法以及各个参数所代表的含义。
1.问题的提出
我们知道RNN一般是用来处理具有时间序列的数据,但是mnist数据集是图像数据,怎么进行处理呢?这里我们将mnist手写字数据的每行当作不同时间点的数据,也就是一张图片是由28个时间序列数据组成,即Time_step=28,每个时间序列数据包含28个像素点,也就是Input_size=28。第一步,我们来创建一些超参数:
# 定义一些超参数
EPOCH = 1 # 训练整批数据多少次, 为了节约时间, 我们只训练一次
BATCH_SIZE = 64
TIME_STEP = 28 # rnn 时间步数 / 图片高度
INPUT_SIZE = 28 # rnn 每步输入值 / 图片每行像素
LR = 0.01 # learning rate
数据也是和前边提到的一样,分为测试集和训练集,方法这里不再进行展示。
2.LSTM网络的搭建
在pytorch中已经有封装好的LSTM模块,我们直接调用nn.LSTM()
来搭建,具体方法如下:
我们创建了只有一个LSTM单元的RNN网络,其后接一个全连接网络进行10分类。
class RNN(nn.Module):
def __init__(self):
super(RNN,self).__init__()
self.rnn = nn.LSTM(
input_size=INPUT_SIZE,
hidden_size=64,
num_layers=1,
batch_first=True
)
self.fc1 = nn.Linear(64,10)
def forward(self,x):
# x shape (batch, time_step, input_size)
# r_out shape (batch, time_step, output_size)
# h_n shape (n_layers, batch, hidden_size) LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
# h_c shape (n_layers, batch, hidden_size)
r_out, (h_n, h_c) = self.rnn(x, None) # None 表示 hidden state 会用全0的 state
# 选取最后一个时间点的r_out 输出
# 这里 r_out[:, -1, :] 的值也就是最后一时刻 h_n (主线)的值
out = self.fc1(r_out[:, -1, :]