pytorch深度学习:RNN循环神经网络(二)

上一节了解到了RNN和LSTM的基础知识,这节我们使用LSTM网络对mnist手写数字数据集进行处理。重点了解LSTM网络的搭建方法以及各个参数所代表的含义。

1.问题的提出

我们知道RNN一般是用来处理具有时间序列的数据,但是mnist数据集是图像数据,怎么进行处理呢?这里我们将mnist手写字数据的每行当作不同时间点的数据,也就是一张图片是由28个时间序列数据组成,即Time_step=28,每个时间序列数据包含28个像素点,也就是Input_size=28。第一步,我们来创建一些超参数:

# 定义一些超参数
EPOCH = 1           # 训练整批数据多少次, 为了节约时间, 我们只训练一次
BATCH_SIZE = 64
TIME_STEP = 28      # rnn 时间步数 / 图片高度
INPUT_SIZE = 28     # rnn 每步输入值 / 图片每行像素
LR = 0.01           # learning rate

数据也是和前边提到的一样,分为测试集和训练集,方法这里不再进行展示。

2.LSTM网络的搭建

在pytorch中已经有封装好的LSTM模块,我们直接调用nn.LSTM()来搭建,具体方法如下:
我们创建了只有一个LSTM单元的RNN网络,其后接一个全连接网络进行10分类。

class RNN(nn.Module):
    def __init__(self):
        super(RNN,self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,
            num_layers=1,
            batch_first=True
        )
        self.fc1 = nn.Linear(64,10)

    def forward(self,x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state

        # 选取最后一个时间点的r_out 输出
        # 这里 r_out[:, -1, :] 的值也就是最后一时刻 h_n (主线)的值
        out = self.fc1(r_out[:, -1, :]
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值