PyTorch 中的 stack() 和 cat()

PyTorch 中的 stack() 和 cat()

stack() 可以加入2个或更多张量的序列,如下所示:

import torch

tensor1 = torch.tensor(2) # The size is [].
tensor2 = torch.tensor(7) # The size is [].
tensor3 = torch.tensor(4) # The size is [].
torch.stack((tensor1, tensor2, tensor3))
# tensor([2, 7, 4])
# The size is [3].

tensor1 = torch.tensor([2, 7, 4]) # The size is [3].
tensor2 = torch.tensor([8, 3, 2]) # The size is [3].
tensor3 = torch.tensor([5, 0, 8]) # The size is [3].
torch.stack((tensor1, tensor2, tensor3))
# tensor([[2, 7, 4], [8, 3, 2], [5, 0, 8]])
# The size is [3, 3].

tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]]) # The size is [2, 3].
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]]) # The size is [2, 3].
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]]) # The size is [2, 3].
torch.stack((tensor1, tensor2, tensor3))
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]]])
# The size is [3, 2, 3].

tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
                       # The size is [2, 2, 3].
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
                       # The size is [2, 2, 3].
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])
                       # The size is [2, 2, 3].
torch.stack((tensor1, tensor2, tensor3))
torch.stack((tensor1, tensor2, tensor3), 0)
# tensor([[[[2, 7, 4], [8, 3, 2]],
#          [[5, 0, 8], [3, 6, 1]]],
#         [[[9, 4, 7], [1, 0, 5]],
#          [[6, 7, 4], [2, 1, 9]]],
#         [[[1, 6, 3], [9, 6, 0]],
#          [[0, 8, 7], [3, 5, 2]]]])
# The size is [3, 2, 2, 3].

torch.stack((tensor1, tensor2, tensor3), 1)
torch.stack((tensor1, tensor2, tensor3), -3)
# tensor([[[[2, 7, 4], [8, 3, 2]],
#          [[9, 4, 7], [1, 0, 5]],
#          [[1, 6, 3], [9, 6, 0]]],
#         [[[5, 0, 8], [3, 6, 1]],
#          [[6, 7, 4], [2, 1, 9]],
#          [[0, 8, 7], [3, 5, 2]]]])
# The size is [2, 3, 2, 3].

torch.stack((tensor1, tensor2, tensor3), 2)
torch.stack((tensor1, tensor2, tensor3), -2)
# tensor([[[[2, 7, 4], [9, 4, 7], [1, 6, 3]],
#          [[8, 3, 2], [1, 0, 5], [9, 6, 0]]],
#         [[[5, 0, 8], [6, 7, 4], [0, 8, 7]],
#          [[3, 6, 1], [2, 1, 9], [3, 5, 2]]]])
# The size is [2, 2, 3, 3].

torch.stack((tensor1, tensor2, tensor3), 3)
torch.stack((tensor1, tensor2, tensor3), -1)
# tensor([[[[2, 9, 1], [7, 4, 6], [4, 7, 3]],
#          [[8, 1, 9], [3, 0, 6], [2, 5, 0]]],
#         [[[5, 6, 0], [0, 7, 8], [8, 4, 7]],
#          [[3, 2, 3], [6, 1, 5], [1, 9, 2]]]])
# The size is [2, 2, 3, 3].

备忘录:

  • stack() 可以连接0D或更多D张量。
  • 张量的大小必须相同。
  • 将维度设置为第二个参数可以更改大小(形状)。
  • 如果至少一个张量至少包含一个浮动-点号,结果是浮动的张量-点号。

cat () 串联 seq 2个或更多张量,如下所示:

import torch

tensor1 = torch.tensor([2, 7, 4]) # The size is [3].
tensor2 = torch.tensor([8, 3, 2]) # The size is [3].
tensor3 = torch.tensor([5, 0, 8]) # The size is [3].
torch.cat((tensor1, tensor2, tensor3))
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])
# The size is [9].

tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]]) # The size is [2, 3].
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]]) # The size is [2, 3].
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]]) # The size is [2, 3].
torch.cat((tensor1, tensor2, tensor3))
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])
# The size is [6, 3].

tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
                       # The size is [2, 2, 3].
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
                       # The size is [2, 2, 3].
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])
                       # The size is [2, 2, 3].
torch.cat((tensor1, tensor2, tensor3))
torch.cat((tensor1, tensor2, tensor3), 0)
torch.cat((tensor1, tensor2, tensor3), -3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])
# The size is [6, 2, 3].

torch.cat((tensor1, tensor2, tensor3), 1)
torch.cat((tensor1, tensor2, tensor3), -2)
# tensor([[[2, 7, 4], 
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])
# The size is [2, 6, 3].

torch.cat((tensor1, tensor2, tensor3), 2)
torch.cat((tensor1, tensor2, tensor3), -1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])
# The size is [2, 2, 9].

备忘录:

  • cat() 可以连接1D或更多张力张量。
  • cat() ,对于某些尺寸不同的张量可能是可能的,但是对于某些尺寸不同的张量是不可能的。
  • 将维度设置为第二个参数可以更改大小(形状)。
  • 如果至少一个张量至少包含一个浮动-点号,结果是浮动的张量-点号。
  • cat()concat() 是一样的,因为 concat()cat() 的别名。
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

琴歌声声送我

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值