LSTMCell

什么是LSTMCell

今天在回顾Seq2Seq利用Attention注意力机制实现的时候,发现decoder中用到的不是普通的LSTM而是LSTMCell,那么它到底什怎么回事?和LSTM又有哪些区别呢?以及在Seq2Seq中起到了什么作用?让我们一探究竟!

LSTMCell含义

avatar

如图是一个RNN按时间步的展开图,RNNCell就相当于一个时间步的处理。

同理,LSTMCell是LSTM的一个单元,LSTMCell就相当于一个时间步的处理。

LSTMCell类

class LSTMCell(RNNCellBase):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        return _VF.lstm_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )

和LSTM相比,LSTMCell参数中没有num_layers(层数)、bidirectional(双向)、dropout选项。

官方文档还提供一个用来同等实现LSTMCell作用的例子

Examples::

        >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
        >>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20) # (batch, hidden_size)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(input.size()[0]):
                hx, cx = rnn(input[i], (hx, cx))
                output.append(hx)
        >>> output = torch.stack(output, dim=0)

可能类的源代码我们还一知半解,不太清楚到底和LSTM有什么区别,下面的这个例子就很好地说明了一切。

其实它就是LSTM操作,只不过每一次执行完LSTMCell后,我们都执行了一步,而有多少个时间步,我们就需要执行多少个LSTMCell。

为什么要用LSTMCell

未完待续~

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值