Long Expressive Memory for Sequence Modeling(LEM模型)模型代码与结构图

LEM

文章前提:该文章主要根据github的代码画出来的计算流程图

github地址:GitHub - tk-rusch/LEM: Official code for Long Expressive Memory (ICLR 2022, Spotlight)

原文中的公式

在这里插入图片描述
在这里插入图片描述

其中u是输入,y,z是隐藏状态(据说对应LSTM中的c与h),𝜎是tanh函数,ˆ𝜎是sigmoid函数

LEM-cell 循环块

在这里插入图片描述

上图中的𝜎是sigmoid函数。顺带一提,代码中使用LEM模型进行预测或分类,通常使用y变量(而不是z变量)

具体代码实现

代码实现时其实没那么复杂

class LEMCell(nn.Module):
    def __init__(self, ninp, nhid, dt):
        super(LEMCell, self).__init__()
        self.ninp = ninp
        self.nhid = nhid
        self.dt = dt
        self.inp2hid = nn.Linear(ninp, 4 * nhid)#一次性对x进行权重转换,后续分割
        self.hid2hid = nn.Linear(nhid, 3 * nhid)#一次性对y进行权重转换,后续分割
        self.transform_z = nn.Linear(nhid, nhid)
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.nhid)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, y, z):#x:(batch_size,c),y:(batch_size,hidden_size),z:(batch_size,hidden_size)
        transformed_inp = self.inp2hid(x)
        transformed_hid = self.hid2hid(y)
        i_dt1, i_dt2, i_z, i_y = transformed_inp.chunk(4, 1)#x分割
        h_dt1, h_dt2, h_y = transformed_hid.chunk(3, 1)#y分割

        ms_dt_bar = self.dt * torch.sigmoid(i_dt1 + h_dt1)
        ms_dt = self.dt * torch.sigmoid(i_dt2 + h_dt2)

        z = (1.-ms_dt) * z + ms_dt * torch.tanh(i_y + h_y)
        y = (1.-ms_dt_bar)* y + ms_dt_bar * torch.tanh(self.transform_z(z)+i_z)

        return y, z

针对模型的个人理解

其实本人并没有对原论文中的数学推导都理解(半懂不懂的。。。),论文中说LEM在某种情况下与LSTM等同,我没看出来。。。但是根据我对其计算流程的观察,比起LSTM,我认为它更像GRU模型,具体相似度各位不用太在意,这种计算流程是否相较于LSTM与GRU更加适用于长序列(减弱梯度爆炸消失问题),从实验结果上看好像有用,后续如果遇到长序列课题或项目时可以试试。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值