pytorch的LSTM层的batch first参数
DataLoader返回数据时候一般第一维都是batch,pytorch的LSTM层默认输入和输出都是batch在第二维。如果按照默认的输入和输出结构,可能需要自己定义DataLoader的collate_fn函数,将batch放在第一维。我一开始就是费了一些劲,捣鼓了半天。后来发现有batch first这个参数,将其设为True就可以将batch放在第一维。(其实一开始看文档的时候注意到了,但是后来写代码忘记它了,回过头来看的时候简直要气死!!)还有就是使用这个参数的时候有一点要注意,看官方文档:
原创
2020-12-14 21:03:12 ·
10793 阅读 ·
6 评论