Pytorch中cat和stack的用法

详解cat和stack

本文为原创,仅供交流学习,转载请注明出处,谢谢!

Pytorch学习说明

在这里推荐两份文档:
1.Pytorch中文手册:这就是一本"新华字典"。
2.动手学习深度学习:进一步学习DeepLearning
最近在学习Pytorch,在CNN的部分遇到了torch.stack()torch.cat()两个函数。在网上查阅了很多博客,才搞清楚这两个函数的作用。在这里稍微总结一下。

torch.Cat()

格式说明:

torch.cat(inputs, dimension=0) → Tensor

参数:
inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
dimension (int, optional) – 沿着此维连接张量序列。


这个函数还是比较好理解的。
我们举个栗子:
我们定义了张量A,张量B(tensor类型也叫张量)。

A=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)

输出结果:

A: tensor([[1., 2., 3.],
        [4., 5., 6.]])
B: tensor([[-1., -2., -3.],
        [-4., -5., -6.],
        [-7., -8., -9.]])

我们可以看到A的尺寸:(2,3), B的尺寸:(3,3)。这里有个小tip:dtype=torch.float,不然会报错。


1. dim=0的情况

接下来我们在dim=0上执行cat,代码如下:

print("dim=0:",torch.cat((A,B),dim=0))

输出结果:

dim=0: tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [-1., -2., -3.],
        [-4., -5., -6.],
        [-7., -8., -9.]])

其实就是将A,B两个张量垂直拼接思考一下:如果两个张量的列不相同会怎么样,还能在dim=0上做cat操作吗?
我们将B的尺寸改为(3,4)结果如下:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-51-b2487a7f4fe4> in <module>()
      4 print("B:",B)
      5 print("*********************************")
----> 6 print("dim=0:",torch.cat((A,B),dim=0))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:711

2. dim=1的情况
类似与dim=0,dim=1其实就是将两个张量水平拼接起来。同样的如果两个张量行号不相同,还是会报错。

A=torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)
print("*********************************")
print("dim=1:",torch.cat((A,B),dim=1))
A: tensor([[1., 2., 3.],
        [4., 5., 6.]])
B: tensor([[-1., -2., -3.],
        [-4., -5., -6.],
        [-7., -8., -9.]])
*********************************

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-a0255954420e> in <module>()
      4 print("B:",B)
      5 print("*********************************")
----> 6 print("dim=1:",torch.cat((A,B),dim=1))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 2 and 3 in dimension 0 at /pytorch/aten/src/TH/generic/THTensor.cpp:711

修改下A,B后,演示如下:

A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]],dtype=torch.float)
print("A:",A)
B=torch.tensor([[-1,-2,-3],[-4,-5,-6],[-7,-8,-9]],dtype=torch.float)
print("B:",B)
print("*********************************")
print("dim=1:",torch.cat((A,B),dim=1))

结果如下:

A: tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
B: tensor([[-1., -2., -3.],
        [-4., -5., -6.],
        [-7., -8., -9.]])
*********************************
dim=1: tensor([[ 1.,  2.,  3., -1., -2., -3.],
        [ 4.,  5.,  6., -4., -5., -6.],
        [ 7.,  8.,  9., -7., -8., -9.]])

我们可以看出输出的是A,B水平拼接的结果。

torch.stack()

格式说明:

torch.stack(sequence, dim=0)

沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
参数:
sqequence (Sequence) – 待连接的张量序列
dim (int) – 插入的维度。必须介于 0 与 待连接的张量序列数之间。


和cat一样我们,在这里举个例子说明:
定义了A,B,C三个张量,这里要求ABC都为相同形状。

A=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
B=10*A
C=100*A
print("A:",A)
print("B:",B)
print("C:",C)

结果如下:

A: tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
B: tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])
C: tensor([[100, 200, 300],
        [400, 500, 600],
        [700, 800, 900]])


1. dim=0

d0=torch.stack((A,B,C),dim=0)
print("dim=0:",d0)
print("d0[0][0][0]:",d0[0][0][0])
print("d0[0][0]:",d0[0][0])
print("d0[0]:",d0[0])

结果如下:

dim=0: tensor([[[  1,   2,   3],
         [  4,   5,   6],
         [  7,   8,   9]],

        [[ 10,  20,  30],
         [ 40,  50,  60],
         [ 70,  80,  90]],

        [[100, 200, 300],
         [400, 500, 600],
         [700, 800, 900]]])
d0[0][0][0]: tensor(1)
d0[0][0]: tensor([1, 2, 3])
d0[0]: tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

我们输出了d1的第2,1,0维元素,我们可以清楚的看到当dim=0时候,torch.stack类似于
torch.cat。我们可以看出,dim=0的作用其实就是把所有的d[i]拼接在了一起。简单的说,将A,B,C垂直拼接。


2. dim=1

d1=torch.stack((A,B,C),dim=1)
print("dim=1:",d1)
print("d1[0][0][0]:",d1[0][0][0])
print("d1[0][0]:",d1[0][0])
print("d1[0]:",d1[0])

结果如下:

dim=1: tensor([[[  1,   2,   3],
         [ 10,  20,  30],
         [100, 200, 300]],

        [[  4,   5,   6],
         [ 40,  50,  60],
         [400, 500, 600]],

        [[  7,   8,   9],
         [ 70,  80,  90],
         [700, 800, 900]]])
d1[0][0][0]: tensor(1)
d1[0][0]: tensor([1, 2, 3])
d1[0]: tensor([[  1,   2,   3],
        [ 10,  20,  30],
        [100, 200, 300]])

我们可以看出,dim=1的作用其实就是把ABC的d[i][i]拼接在了一起。也就是说将每个张量的相同位置的行向量拼接在一起。如A[1],B[1],C[1]=[4,5,6],[40,50,60],[400,500,600]
3. dim=2

d2=torch.stack((A,B,C),dim=2)
print("dim=2",d2)
print("d2[0][0][0]:",d2[0][0][0])
print("d2[0][0]:",d2[0][0])
print("d2[0]:",d2[0])

结果如下:

dim=2 tensor([[[  1,  10, 100],
         [  2,  20, 200],
         [  3,  30, 300]],

        [[  4,  40, 400],
         [  5,  50, 500],
         [  6,  60, 600]],

        [[  7,  70, 700],
         [  8,  80, 800],
         [  9,  90, 900]]])
d2[0][0][0]: tensor(1)
d2[0][0]: tensor([  1,  10, 100])
d2[0]: tensor([[  1,  10, 100],
        [  2,  20, 200],
        [  3,  30, 300]])

我们可以看出,dim=2的作用其实就是把ABC的d[i][i][i]拼接在了一起。即将A,B,C三个张量的相同位置的元素拼接在一起。如A[1][1],B[1][1],C[1][1]=5,50,500

本人pytorch小白,完全为了交流学习。仅为个人理解,如果有错误地方请指出。欢迎一起交流学习。

  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值