【Pytorch基础】torch.stack()函数解析

1 函数作用

  官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,注意与torch.cat的区别
  浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

2 例子

import torch
import numpy as np
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]], dtype=torch.int32)

2.1 沿dim=0拼接

d = torch.stack((a, b), dim=0)
print(d)
print(d.size())
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]], dtype=torch.int32)
torch.Size([2, 3, 3])

  当dim = 0,原来的每一个矩阵也变成了一个维度,一个矩阵看做一个整体。有几个矩阵,新的维度就是几,第几个矩阵就是第几维。

2.2 dim=1

d = torch.stack((a, b), dim=1)
print(d)
print(d.size())
tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]], dtype=torch.int32)
torch.Size([3, 2, 3])

  将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)

2.3 dim=2

d = torch.stack((a, b), dim=2)
print(d)
print(d.size())
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]], dtype=torch.int32)
torch.Size([3, 3, 2])

dim=2的理解可以参考文献【3】

3 参考文献

[1]【Pytorch】torch.stack()的使用
[2]看完秒懂torch.stack()
[3]初学torch.stack()对dim的个人理解

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值