torch.cat( )与torch.stack( )

torch.cat( )

官方:连接给定维数的序列张量。所有张量要么具有相同的形状(连接维度除外),要么为空。

 例子:注意每个维度的shape

import torch
x = torch.randn(2, 3, 4)
rx1 = torch.cat((x, x, x), 0)
rx2 = torch.cat((x, x, x), 1)
print(x)
print(rx1)
print(rx2)

结果:

tensor([[[ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929]],

        [[-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386]]])
tensor([[[ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929]],

        [[-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386]],

        [[ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929]],

        [[-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386]],

        [[ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929]],

        [[-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386]]])
tensor([[[ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929],
         [ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929],
         [ 0.3334,  0.9134,  1.0101, -1.8516],
         [ 1.2072, -0.2558,  1.1317,  1.6834],
         [ 0.3684,  1.2080, -1.0300,  0.4929]],

        [[-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386],
         [-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386],
         [-0.0713,  1.4526, -0.0669, -1.2307],
         [ 0.2329, -0.6852, -0.3124,  0.0389],
         [ 1.5162, -0.2998,  0.5022,  0.3386]]])
torch.Size([2, 3, 4])
torch.Size([6, 3, 4])
torch.Size([2, 9, 4])

沿0维拼接,0维度因为3个拼接故从2到6。沿1维拼接,1维度因为3个拼接故从3到9。

torch.stack( )

官方:沿着一个新的维度连接张量序列。所有张量的大小必须相同。

 例子:注意每个维度的shape

import torch
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)
rx1 = torch.stack([x1, x2], dim=0)
rx2 = torch.stack([x1, x2], dim=1)
print(x1)
print(x2)
print(rx1)
print(rx2)
print(x1.shape)
print(x2.shape)
print(rx1.shape)
print(rx2.shape)

结果:

tensor([[ 0.3013, -0.2763,  1.1704],
        [ 0.2154,  0.8657, -0.2266]])
tensor([[ 0.2167, -0.5974,  0.9155],
        [-0.1918, -0.4248,  0.9019]])
tensor([[[ 0.3013, -0.2763,  1.1704],
         [ 0.2154,  0.8657, -0.2266]],

        [[ 0.2167, -0.5974,  0.9155],
         [-0.1918, -0.4248,  0.9019]]])
tensor([[[ 0.3013, -0.2763,  1.1704],
         [ 0.2167, -0.5974,  0.9155]],

        [[ 0.2154,  0.8657, -0.2266],
         [-0.1918, -0.4248,  0.9019]]])
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 2, 3])
torch.Size([2, 2, 3])

Process finished with exit code 0

两个张量,沿0维拼接,产生0维,因为两个拼接故0维为2 。两个张量,沿1维拼接,产生1维,因为两个拼接故1维为2 。

区别:cat是沿维数直接拼接,stack是沿着维数产生维度再进行拼接。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值