- 作为新手学习和使用Lstm处理数据时,多数会对这个参数有些疑惑,经过长时间的查找资料和个人理解,对这个参数有以下认识:
1、此参数的作用
loader = Data.DataLoader(dataset=torch_data_set, batch_size=10, shuffle=True, num_workers=5, )
- 很多同学使用Pytorch开发lstm处理时序问题时,一般都需要组织时序滑动窗口作为训练数据,选择DataLoader组织数据非常方便,但是会发现DataLoader组织出的数据第一个维度是batch_size, 大家都知道lstm要求的入参顺序为(sql_len , btach_size , input_size),因此,我们的主人公:batch_first 参数就派上用场了,只需要将batch_first = true,即可解决问题,继续正常的使用lstm模型训练你的数据。
2、参数的改变对训练速度的影响
1、虽然设置这个参数可以简单快速的解决问题,但是我们从训练速度角度观察,会发现设置batch_first=True后,多数情况下训练速度大幅下降,这是为什么呢?
2、原因是cuDNN中RNN的Api就是batch_size在第二个维度,这么设置的目的如下(以下举例非原创,摘抄自-知乎文章:读PyTorch源码学习RNN)
举个例子,假设输入序列的长度(seq_len)是3,batch_size是2,一个batch的数据是[[“A”, “B”, “C”], [“D”, “E”, “F”]],如图1所示。由于RNN是序列模型,只有t1时刻计算完成,才能进入t2时刻,而"batch"就体现在每个时刻ti的计算过程中,图1中 t1时刻将[“A”, “D”]作为当前时刻的batch数据,t2时刻将[“B”, “E”]作为当前时刻的batch数据,可想而知,“A"与"D"在内存中相邻比"A"与"B"相邻更合理,这样取数据时才更高效。而不论Tensor的维度是多少,在内存中都以一维数组的形式存储,batch first意味着Tensor在内存中存储时,先存储第一个sequence,再存储第二个… 而如果是seq_len first,模型的输入在内存中,先存储所有sequence的第一个元素,然后是第二个元素… 两种区别如图2所示,seq_len first意味着不同sequence中同一个时刻对应的输入元素(比如"A”, “D” )在内存中是毗邻的,这样可以快速读取数据。
3、如何不改变RNN的默认入参顺序解决问题?
可以选择permute函数解决:
batch_x = batch_x.permute((1, 0, 2))
此方式将一个批次内的shape从例子中的 [2,3] 转置为[3, 2] ,可以在不影响训练速度的原则下解决问题
[参考资料]:https://zhuanlan.zhihu.com/p/32103001