pytorch 重写lstm 使用mask

本文详细介绍了如何在PyTorch中从零开始重写LSTM模块,并结合mask处理变长输入序列的问题,适用于自然语言处理等场景。
摘要由CSDN通过智能技术生成
    @staticmethod
    def _forward_rnn(cell, input, masks, initial, drop_masks):
        max_time = input.size(0) # seq_len:41
        output = []
        hx = initial # ([32,200], [32,200])  初始化值全为0
        for time in range(max_time):
            h_next, c_next = cell(input=input[time], hx=hx)  # input[time]为[32,100](为batch里面一个位置上所以词), 经过一个lstmcell输出的h_n和c_n  都为(32,200)
            h_next = h_next*masks[time] + initial[0]*(1-masks[time]) #  masks(41,32,200),masks[time]为一个(32,200)值为0或者1,
            c_next = c_next*masks[time] + initial[1]*(1-masks[time])  # 这里后面半句不应该一直都是0吗?
            output.append(h_next) # 0-40每个位置上的输出
            if drop_masks is not None: h_next = h_next * drop_masks
            hx = (h_next, c_next) # 把上一个h_n,c_n作为参数传入下一个lstmcell
        output = torch.stack(output, 0) # 把列表连接成(41,32,200)
        return output, hx # 返回的结果是每一列输出h_n连接成的(41,32,200)&
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值