torch.stack
总结:先看有几个t,两个t那么在拓展维度上(dim=0,1,2)的数字变成2。
三个t那么(dim=0,1,2)的数字变成3。
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=0)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
第0维,就是第一个数字,组合就是:2,2,3。就是2个2行3列的数组。
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([2, 2, 3])
Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=0)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
相比于torch.stack([t,t],dim=0)的2,2,3变成了torch.stack([t,t,t],dim=0)的3,2,3。
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([3, 2, 3])
Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=1)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在维度1的位置上stack,也就是变成了:2,2,3。2个2行3列的数组。
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([2, 2, 3])
Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=1)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在维度dim=1的位置进行了torch.stack([t,t,t])拓维操作,由于是3个t,所以dim=1的位置上由原来的torch.stack([t,t],dim=1)的2,2,3变成了2,3,3。
'''
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([2, 3, 3])
Process finished with exit code 0
# torch.stack
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
这个地方的dim=2,自己创建了个维度,而且数值就是2,那么就形成了,2,3,2。无论前面两个数字怎么变化,由于torch.stack([t,t]),这个地方是2个t,所以第三个维度的数值始终都是2。就是2个3行2列的数组。
t_stack:tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]]) shape:torch.Size([2, 3, 2])
Process finished with exit code 0
'''
flag = True
# flag = False
if flag:
t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))
'''
在这段代码中的torch.stach([t,t,t]),有3个t,再在第三个维度上dim=2进行创建,尺寸会变成2个3行3列的数组。
t_stack:tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]) shape:torch.Size([2, 3, 3])
'''