torch.stack(sequence, dim)
sqequence– 待连接的张量序列
dim (int) – 插入的维度。
torch.stack()函数和torch.cat()有所不同,torch.stack()并不在已有的维度进行拼接,而是沿着新的维度进行拼接。
我在使用torch.stack()产生了两个问题:
1.怎么确定新的维度产生在哪里?
2.指定了新维度后要怎么拼接?
下面我以两个张量来说明,分别是A和B
A = torch.arange(6.0).reshape(2,3)
B = torch.linspace(0,10,6).reshape(2,3)
A和B是这样的2维张量
A= tensor([[0., 1., 2.],
[3., 4., 5.]])
B= tensor([[ 0., 2., 4.],
[ 6., 8., 10.]])
下面说说我对torch.stack()函数的使用理解:
既然知道函数会为张量产生一个新维度,那么我们可以假设,令A和B维度升级,从(2,3)变为(1,2,3),即:
A1= tensor([[[0., 1., 2.],
[3., 4., 5.]]])
B1= tensor([[[ 0., 2., 4.],
[ 6., 8., 10.]]])
#A1、B1比A、B在最外层多了一组括号[]
这样,接下里就很好解释了。
参数dim表示相连维度在这3维里的索引,如用link表示连接维度:
dim=0时,(link,#,#)
dim=1时,(#,link,#)
dim=2时,(#,#,link)
link所在的维度是哪个,就把A1和B1对应维度里的元素逐个相连。
下面我对每个维度都演示一遍
dim=0
F1 = torch.stack((A,B),dim=0)
print('F1=',F1)
print('F1的维度是',F1.shape)
运行结果:
F1= tensor([[[ 0., 1., 2.],
[ 3., 4., 5.]],
[[ 0., 2., 4.],
[ 6., 8., 10.]]])
F1的维度是 torch.Size([2, 2, 3])
用上面的说法来理解,这相当于在A1和B1的第0维度里,每个元素依次相连,每对连接元素用[ ]包装。
A1的第0维度里只有A一个元素
B1的第0维度里只有B一个元素
因此如上所示,F1的结果其实就是[A,B]
dim=1
F2 = torch.stack((A,B),dim=1)
print('F2=',F2)
print('F2的维度是',F2.shape)
运行结果:
F2= tensor([[[ 0., 1., 2.],
[ 0., 2., 4.]],
[[ 3., 4., 5.],
[ 6., 8., 10.]]])
F2的维度是 torch.Size([2, 2, 3])
沿用上面的理解
A1的第1维度里的两个元素:[0. 1, 2],[3, 4, 5]
B1的第1维度里的两个元素:[0, 2, 4],[6, 8, 10]
[0. 1, 2]和[0, 2, 4]相连,[ ]包起来
[3, 4, 5]和[6, 8, 10]相连,[ ]包起来
最后给以上两组用[ ]包起来
dim=2
F3 = torch.stack((A,B),dim=2)
print('F3=',F3)
print('F3的维度是',F3.shape)
运行结果
F3= tensor([[[ 0., 0.],
[ 1., 2.],
[ 2., 4.]],
[[ 3., 6.],
[ 4., 8.],
[ 5., 10.]]])
F3的维度是 torch.Size([2, 3, 2])
A1和B1的第2维度里每个元素依次相连,每对连接元素用[]包装
A1第2维度里的元素:0,1,2,3,4,5
B1第2维度里的元素:0,2,4,6,8,10
两两相连后打包[0,0] [1,2] [2,4] [3,6] [4,8] [5,10]
由于[0,1,2]和[0,2,4]的第1维度属性是0
[3,4,5]和[6,8,10]的第1维度属性是1
所以把第一维度属性是0的和是1的单独打包
即[[0,0],[1,2],[2,4]]和[[3,6],[4,8],[5,10]]
最后将以上两组一起[]包起来