pytorch里 RNN 网络多GPU的坑

简单的概述一下问题。

在使用pytorch的时候,网络中用到了RNN/GRU/LSTM网络,由于输入的一个batch里数据长度不一致,所以数据进行了补0操作来对齐。然后在送入RNN之前需要使用pack_padded_sequence将数据进行压缩,输入后,再使用pad_packed_sequence将结果还原到对齐的长度。这两个函数的具体使用和意义,网上一搜很多。
举个例子, 一个batch里数据原始长度分别为 [300,200,100,50],统一补齐为300后。
这三个操作pack sequence -> recurrent network -> unpack sequence之后,得到的output是batch_size * seq_length * output_size, 这里seq_length是这个最长的length,也就是300.

这里遇到的问题是,这样的网络送入多GPU,使用nn.DataParallel后会有一个问题。
咱们假设有2个GPU,这样每个GPU处理的数据batch_size=2, 也就是分成了[300,200] , [100,50] 两组数据,经过RNN之后,前面的那组数据按照最长的length数据,得到长度为300的outputs,而后面那组数据则会得到长度为100的outputs,本质上来说就是每个GPU处理的时候都是局部的看自己的batch数据的长度。

这样最后如果return这个outputs,多个GPU的结果会合并,但是长度不一致,就会出错。

解决方案是pytorch修正了这个问题,引入了一个参数,total_length,将这个最大的长度传入,则所有的GPU的outputs会统一长度。

My recurrent network doesn’t work with data parallelism
There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module with DataParallel or data_parallel(). Input to each the forward() on each device will only be part of the entire input. Because the unpack operation torch.nn.utils.rnn.pad_packed_sequence() by default only pads up to the longest input it sees, i.e., the longest on that particular device, size mismatches will happen when results are gathered together. Therefore, you can instead take advantage of the total_length argument of pad_packed_sequence() to make sure that the forward() calls return sequences of same length. For example, you can write:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class MyModule(nn.Module):
    # ... __init__, other methods, etc.

    # padded_input is of shape [B x T x *] (batch_first mode) and contains
    # the sequences sorted by lengths
    #   B is the batch size
    #   T is max sequence length
    def forward(self, padded_input, input_lengths):
        total_length = padded_input.size(1)  # get the max sequence length
        packed_input = pack_padded_sequence(padded_input, input_lengths,
                                            batch_first=True)
        packed_output, _ = self.my_lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                        total_length=total_length)
        return output


m = MyModule().cuda()
dp_m = nn.DataParallel(m)

文章的最后,总结一下pytorch里使用RNN的一些注意事项。

  1. 不同length的数据需要做padding,补齐。补齐的代码网上找一找很多
  2. 注意,后面使用pack_padded_sequence等函数需要batch里的数据长度降序排列。利用Dataloader的collate_fn参数,参考https://blog.csdn.net/u011550545/article/details/89529977
  3. 三件套, 参考 https://zhuanlan.zhihu.com/p/34418001
    packed_input = pack_padded_sequence(padded_input, input_lengths,
                                        batch_first=True)
    packed_output, _ = self.my_lstm(packed_input)
    output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                    total_length=total_length)
    
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值