【PyTorch】TypeError: stack(): argument 'tensors' (position 1) must be tuple

0x00 前言

近期有个版本适配的任务,说白了就是把 PyTorch 0.3.0 的代码更新适配 PyTorch 1.0.2
PyTorch 的向上兼容性在此时就可以体现出来了,令人欣慰的是直接升级版本后并没有太多报错,
其中一个比较突出的问题就是 torch.stack()torch.cat() 的变化。

0x01 错误演示 @ torch.stack()

>>> x = torch.randn(3, 2, 3)
>>> torch.stack(x, 0)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-97ff2667b2e9> in <module>()
      1 import torch
      2 x = torch.randn(3, 2, 3)
----> 3 torch.stack(x, 0)

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

现在 torch.stack() 的输入必须是一个 python-list ,里面每个元素都是一个 tensor,
不可以和原先 0.3.0 的时候一样,输入一个矩阵直接按某一个维度进行 stack 操作。

0x02 错误演示 @ torch.cat()

>>> x = torch.randn(3)
>>> x
tensor([ 0.5710,  0.4324, -0.5154])

>>> x.sum()
tensor(0.4879)

>>> torch.cat([x.sum(), x.sum()], 0)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-22-38992e81433c> in <module>()
----> 1 torch.cat([x.sum(), x.sum()], 0)

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

torch.cat() 现在不允许将 0 维 tensor 拼在一起了,如果有类似需求可以考虑直接新建一个 Tensor。

0x03 正确示范

>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.cat([x, x, x], 0)

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]

>>> torch.cat((x, x, x), 1)

 0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]

0x04 代码更新 DIFF实例

刚学pytorch的时候写的代码,现在看起来……
好想重构啊,算了任务太多来不及以后再说吧

torch.stack()

# Original
t1 = torch.stack(t.repeat(seg_len, 1, 1, 1), 2).view(-1, hid_size)
t2 = torch.stack(t.repeat(seg_len, 1, 1, 1), 1).view(-1, hid_size)

# New
rep_t = [t] * seg_len
t1 = torch.stack(rep_t, 2).view(-1, hid_size)
t2 = torch.stack(rep_t, 1).view(-1, hid_size)

torch.cat()

# Original
macro_count = torch.stack(
            [torch.cat([positive_true.sum(), positive_false.sum()], 0),
             torch.cat([negative_true.sum(), negative_false.sum()], 0)], 1)
# New
macro_count = torch.LongTensor([
            [positive_true.sum(), positive_false.sum()],
            [negative_true.sum(), negative_false.sum()]])

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

糖果天王

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值