pytorch的LSTM层的batch first参数

DataLoader返回数据时候一般第一维都是batch,pytorch的LSTM层默认输入和输出都是batch在第二维。在这里插入图片描述
如果按照默认的输入和输出结构,可能需要自己定义DataLoader的collate_fn函数,将batch放在第一维。

我一开始就是费了一些劲,捣鼓了半天。后来发现有batch first这个参数,将其设为True就可以将batch放在第一维。(其实一开始看文档的时候注意到了,但是后来写代码忘记它了,回过头来看的时候简直要气死!!)

还有就是使用这个参数的时候有一点要注意,看官方文档:
在这里插入图片描述
设置batch first为true后,input和output都会变为batch在第一维,但是我们有时候也会用到hn和cn,那它们两个是会变呢还是不变呢?
作为懒星人,先去百度了一下,有一篇博客是这样说的:
在这里插入图片描述
所以我在写代码时就按照博客所说的来写了,但是报错了。。。。
只能自己上手实验了。


```python
import torch.nn as nn
import torch
import numpy as np

model = nn.LSTM(input_size=6, hidden_size=10, num_layers=1, batch_first=True)
model = model.double()

x = np.random.randn(100, 10, 6)

x = torch.from_numpy(x)
print(x.shape)

y, (hn, cn) = model(x)  # 不提供h0和c0,默认全0
print('y:', y.shape)
print('hn:', hn.shape)
print('cn:', cn.shape)

运行结果:

在这里插入图片描述

根据运行结果来看,设置batch first为true,只有输入input和输出output的batch会在第一维,hn和cn是不会变的。使用的时候要注意,会很容易弄混。
还有就是,这里并没有提供h0和c0,如果需要提供h0和c0,也需要注意shape。

  • 27
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值