Pytorch lstm中batch_first 参数理解使用

  • 作为新手学习和使用Lstm处理数据时,多数会对这个参数有些疑惑,经过长时间的查找资料和个人理解,对这个参数有以下认识:

1、此参数的作用

 loader = Data.DataLoader(dataset=torch_data_set, batch_size=10, shuffle=True, num_workers=5, )
  • 很多同学使用Pytorch开发lstm处理时序问题时,一般都需要组织时序滑动窗口作为训练数据,选择DataLoader组织数据非常方便,但是会发现DataLoader组织出的数据第一个维度是batch_size, 大家都知道lstm要求的入参顺序为(sql_len , btach_size , input_size),因此,我们的主人公:batch_first 参数就派上用场了,只需要将batch_first = true,即可解决问题,继续正常的使用lstm模型训练你的数据。

2、参数的改变对训练速度的影响

1、虽然设置这个参数可以简单快速的解决问题,但是我们从训练速度角度观察,会发现设置batch_first=True后,多数情况下训练速度大幅下降,这是为什么呢?
2、原因是cuDNN中RNN的Api就是batch_size在第二个维度,这么设置的目的如下(以下举例非原创,摘抄自-知乎文章:读PyTorch源码学习RNN)
举个例子,假设输入序列的长度(seq_len)是3,batch_size是2,一个batch的数据是[[“A”, “B”, “C”], [“D”, “E”, “F”]],如图1所示。由于RNN是序列模型,只有t1时刻计算完成,才能进入t2时刻,而"batch"就体现在每个时刻ti的计算过程中,图1中 t1时刻将[“A”, “D”]作为当前时刻的batch数据,t2时刻将[“B”, “E”]作为当前时刻的batch数据,可想而知,“A"与"D"在内存中相邻比"A"与"B"相邻更合理,这样取数据时才更高效。而不论Tensor的维度是多少,在内存中都以一维数组的形式存储,batch first意味着Tensor在内存中存储时,先存储第一个sequence,再存储第二个… 而如果是seq_len first,模型的输入在内存中,先存储所有sequence的第一个元素,然后是第二个元素… 两种区别如图2所示,seq_len first意味着不同sequence中同一个时刻对应的输入元素(比如"A”, “D” )在内存中是毗邻的,这样可以快速读取数据。
图1 RNN示例
图2 batch first vs seq_len first

3、如何不改变RNN的默认入参顺序解决问题?

可以选择permute函数解决:

batch_x = batch_x.permute((1, 0, 2))

此方式将一个批次内的shape从例子中的 [2,3] 转置为[3, 2] ,可以在不影响训练速度的原则下解决问题

[参考资料]:https://zhuanlan.zhihu.com/p/32103001

  • 10
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值