PyTorch LSTM hidden state initialization dimension error: RuntimeError: Expected hidden[0] size (2, 14, 150), got [2, 64, 150]

The error message is as follows:

Traceback (most recent call last):
  File "main.py", line 41, in <module>
    loss = models(x, size, y).abs()
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/kaggle/working/model.py", line 236, in forward
    emissions = self.bi_lstm_forward(sentence, sentence_lengths)
  File "/kaggle/working/model.py", line 227, in bi_lstm_forward
    lstm_out, self.hidden = self.lstm(embeds, hidden)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 759, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 686, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 226, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
RuntimeError: Expected hidden[0] size (2, 14, 150), got [2, 64, 150]

The message roughly means that the expected hidden state size is [2, 14, 150], but the one actually passed in is [2, 64, 150].
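For context, nn.LSTM expects the initial (h_0, c_0) to have shape (num_layers * num_directions, batch, hidden_size), so the batch dimension must match the batch that is actually fed in. A minimal sketch using the numbers from the traceback (input_size and seq_len here are arbitrary):

import torch
import torch.nn as nn

# Single-layer bidirectional LSTM with hidden_size = 150:
# (h_0, c_0) must have shape (num_layers * num_directions, batch, hidden_size).
lstm = nn.LSTM(input_size=300, hidden_size=150, bidirectional=True, batch_first=True)

batch = 14                          # size of the last, incomplete batch
x = torch.randn(batch, 20, 300)     # (batch, seq_len, input_size)
h0 = torch.zeros(2, batch, 150)     # must be (2, 14, 150) here, not (2, 64, 150)
c0 = torch.zeros(2, batch, 150)
out, (hn, cn) = lstm(x, (h0, c0))   # shapes agree, no RuntimeError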

Below is the method in the model code that builds the initial hidden state, together with the relevant configuration:

'''
config.py
'''
batch_size = 64
epochs = 50
embedding_dim = 300
hidden_dim = 300
'''
model.py
'''
def get_state(self):
    # always uses config.batch_size, even when the last batch is smaller
    c0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    ### * self.num_directions = 2 if bi
    h0_encoder = torch.zeros(2, config.batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)

My code always initializes the hidden state with shape [2, batch_size, hidden_dim // 2]. However, when the data is drawn in batches, the batch is not always exactly batch_size: once fewer than batch_size samples remain, the final batch simply contains whatever is left. In this error, only 14 samples remained, so that batch holds 14 samples rather than 64 (batch_size), and its hidden state should therefore have shape [2, 14, 150], as illustrated in the sketch below.
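This behaviour is easy to see with a DataLoader (a sketch assuming the data is batched with torch.utils.data.DataLoader; setting drop_last=True would instead discard the incomplete final batch):

import torch
from torch.utils.data import DataLoader, TensorDataset

# 398 = 6 * 64 + 14 samples, so with batch_size = 64 and drop_last=False
# (the default) the final batch contains only 14 samples.
dataset = TensorDataset(torch.randn(398, 10))
loader = DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False)

for (x,) in loader:
    print(x.size(0))   # 64, 64, 64, 64, 64, 64, 14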

Fixing the code

Pass the current batch of input data into get_state on every call, so the batch size is read dynamically from the data itself:

def get_state(self, input):
    # take the batch size from the actual input (assumes the batch dimension comes first)
    batch_size = input.size(0)
    c0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    ### * self.num_directions = 2 if bi
    h0_encoder = torch.zeros(2, batch_size, self.hidden_dim // 2)
    h0_encoder = h0_encoder.to(config.device)
    c0_encoder = c0_encoder.to(config.device)
    return (h0_encoder, c0_encoder)
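
The call site in bi_lstm_forward then passes the current input when rebuilding the hidden state for each batch. The original forward code is not shown in the post, so the sketch below is an assumption about its shape; only the self.lstm(embeds, hidden) line comes from the traceback:

def bi_lstm_forward(self, sentence, sentence_lengths):
    embeds = self.embedding(sentence)   # (batch, seq_len, embedding_dim); layer name assumed
    hidden = self.get_state(sentence)   # hidden batch dim now follows the real batch, e.g. 14
    lstm_out, self.hidden = self.lstm(embeds, hidden)
    return lstm_out

Since these initial states are all zeros anyway, an even simpler alternative is to pass no hidden state at all: nn.LSTM defaults h_0 and c_0 to zero tensors whose batch dimension matches the input.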