pytorch入门(2)torch.stack(特别绕的维度)

torch.stack

总结:先看有几个t,两个t那么在拓展维度上(dim=0,1,2)的数字变成2。
三个t那么(dim=0,1,2)的数字变成3。

flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t], dim=0)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
第0维,就是第一个数字,组合就是:2,2,3。就是2个2行3列的数组。
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 2, 3])

Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t, t], dim=0)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
相比于torch.stack([t,t],dim=0)的2,2,3变成了torch.stack([t,t,t],dim=0)的3,2,3。
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([3, 2, 3])

Process finished with exit code 0
'''

flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t], dim=1)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在维度1的位置上stack,也就是变成了:2,2,3。2个2行3列的数组。
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 2, 3])

Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t, t], dim=1)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在维度dim=1的位置进行了torch.stack([t,t,t])拓维操作,由于是3个t,所以dim=1的位置上由原来的torch.stack([t,t],dim=1)的2,2,3变成了2,3,3。
'''
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 3, 3])

Process finished with exit code 0
# torch.stack
flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t], dim=2)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
这个地方的dim=2,自己创建了个维度,而且数值就是2,那么就形成了,2,3,2。无论前面两个数字怎么变化,由于torch.stack([t,t]),这个地方是2个t,所以第三个维度的数值始终都是2。就是2个3行2列的数组。
t_stack:tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]]) shape:torch.Size([2, 3, 2])

Process finished with exit code 0
'''

flag = True
# flag = False
if flag:
    t = torch.ones((2, 3))
    t_stack = torch.stack([t, t, t], dim=2)
    print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在这段代码中的torch.stach([t,t,t]),有3个t,再在第三个维度上dim=2进行创建,尺寸会变成2个3行3列的数组。
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 3, 3])
'''
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值