pytorch:chunks为0的报错

File "D:/Codes/code/Python Project/group_reid-master/group_reid-master/main_group_gcn_siamese_part_half_fulltest_sink.py", line 348, in train_gcn
    loss.backward()
  File "D:\Codes\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "D:\Codes\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\autograd\__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: chunk expects `chunks` to be greater than 0, got: 0
Exception raised from chunk at ..\aten\src\ATen\native\TensorShape.cpp:496 (most recent call first):

如图,一直报错chunk为0,最开始百思不得其解,在网上找资料没有和我类似的情况,键入报错的地方发现在loss求导时报错(下图),我寻思loss是调用的函数,不可能会报这种错误啊。于是直接丢在服务器上debug试试,由于pytorch的版本不同,报错的内容也不同。在1.1版本的pytorch上终于找到出错的地方了。

            loss.backward()

原来是在切块的时候维度不匹配导致的:

env11_junk1 = env11.squeeze().unsqueeze(0).unsqueeze(0).repeat((5-x1_valid.shape[0]), parts, 1)
env22_junk2 = env22.squeeze().unsqueeze(0).unsqueeze(0).repeat((5-x2_valid.shape[0]), parts, 1)
env11 = env11.squeeze().unsqueeze(0).unsqueeze(0).repeat(x1_valid.shape[0], parts, 1)
env22 = env22.squeeze().unsqueeze(0).unsqueeze(0).repeat(x2_valid.shape[0], parts, 1)

  # calculate within graph and inter graph message
h_k1 = torch.cat((self.W_x(x1[i, :sample_size1, :]), self.W_neib(x_neib1), self.W_relative(mu1), self.W_env(env11)), 2).unsqueeze(0)  
h_k_junk1 = torch.cat((self.W_x(x1[i, sample_size1:, :]), self.W_x(x1[i, sample_size1:, :]), self.W_x(x1[i, sample_size1:, :]),self.W_env(env11_junk1)), 2).unsqueeze(0)

h_k2 = torch.cat((self.W_x(x2[i, :sample_size2, :]), self.W_neib(x_neib2), self.W_relative(mu2), self.W_env(env22)), 2).unsqueeze(0)
h_k_junk2 = torch.cat((self.W_x(x2[i, sample_size2:, :]), self.W_x(x2[i, sample_size2:, :]), self.W_x(x2[i, sample_size2:, :]),self.W_env(env22_junk2)), 2).unsqueeze(0)                       

在我的代码中(squeeze和unsqueeze用的多余了,直接premute就好了),我本意是想在第一维度上复制和X1相同的第一维度,但是在实际的数据集中X1的第一维度有可能是为0的,因此repeat的参数为0的话就会报错了,不能小于原始的dim。由于报的错误不明显,所以耽误了半天时间想这个问题,特此记录一下。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值