相同点
都是沿着tensor指定维度进行拼接
不同点
直接通过实例进行展示吧…例如:现在分别输入三个tensor,a,b,c,如下图所示值
# input
import torch
a = torch.IntTensor([[1,2,3],[4,5,6],[7,8,9]]) # torch.Size([3, 3])
b = torch.IntTensor([[11,22,33],[44,55,66],[77,88,99]]) # torch.Size([3, 3])
c = torch.IntTensor([[111,222,333],[444,555,666],[777,888,999]]) # torch.Size([3, 3])
'''
a: tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
b: tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]], dtype=torch.int32)
c: tensor([[111, 222, 333],
[444, 555, 666],
[777, 888, 999]], dtype=torch.int32)
'''
- torch.cat():只能在已有维度上进行拼接!!——串联
# cat
cat_0 = torch.cat((a,b,c),dim=0) # torch.Size([9, 3])
cat_1 = torch.cat((a,b,c),dim=1) # torch.Size([3, 9])
cat_2 = torch.cat((a,b,c),dim=2) # error
# cat_output
'''
cat_0:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99],
[111, 222, 333],
[444, 555, 666],
[777, 888, 999]], dtype=torch.int32)
cat_1:
tensor([[ 1, 2, 3, 11, 22, 33, 111, 222, 333],
[ 4, 5, 6, 44, 55, 66, 444, 555, 666],
[ 7, 8, 9, 77, 88, 99, 777, 888, 999]], dtype=torch.int32)
cat_2:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''
- torch.stack():增加一维,同时在指定进行拼接——并联
# stack
stack_0 = torch.stack((a,b,c),dim=0) # torch.Size([3, 3, 3])
stack_1 = torch.stack((a,b,c),dim=1) # torch.Size([3, 3, 3])
stack_2 = torch.stack((a,b,c),dim=2) # torch.Size([3, 3, 3])
# stack_output
'''
stack_0:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99]],
[[111, 222, 333],
[444, 555, 666],
[777, 888, 999]]], dtype=torch.int32)
stack_1:
tensor([[[ 1, 2, 3],
[ 11, 22, 33],
[111, 222, 333]],
[[ 4, 5, 6],
[ 44, 55, 66],
[444, 555, 666]],
[[ 7, 8, 9],
[ 77, 88, 99],
[777, 888, 999]]], dtype=torch.int32)
stack_2:
tensor([[[ 1, 11, 111],
[ 2, 22, 222],
[ 3, 33, 333]],
[[ 4, 44, 444],
[ 5, 55, 555],
[ 6, 66, 666]],
[[ 7, 77, 777],
[ 8, 88, 888],
[ 9, 99, 999]]], dtype=torch.int32)
'''
总结
- torch.cat()对tensors沿指定维度拼接,返回的tensor的维数不会变——串联
- torch.stack()同样也是对tensors沿指定维度拼接,但返回的tensor会多一维。如:对两个“1X2“维的tensor在dim=0上stack,则会变为"2X1X2"的tensor;在dim=1上stack,则会变为"1X2X2"的tensor;在dim=2上stack,也会变为"1X2X2"的tensor,虽然与dim=1上stack最终产生的tensor维度相同,但是叠加的维度是不一样的,结果也是不一样的。——并联