网上很多的示例,都在讨论二维数据(矩阵),单是对于做图像与深度学习的人来说均是三维起步,一般都是4维,下边以4维数据举例
对于pytorch中的堆叠与拼接函数stack与cat,二者还是有一定的不同
torch.cat这是一个拼接函数(姑且这么说)
直接上例子
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.Size([1, 1, 2, 4])
torch.cat((a0,a1),dim=0).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1.],
[2., 2., 2., 2.]]],
[[[3., 3., 3., 3.],
[4., 4., 4., 4.]]]])
torch.Size([2, 1, 2, 4])
上边dim=0,为以第一维为基准拼接,对于一个张量的维度,有几个放括号就是几维,上边例子a0与a1均为4维张量,因此以第0维拼接就是将第一个中括号内的内容进行拼接。最终的尺度大小为(2,1,2,4)
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=1).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1.],
[2., 2., 2., 2.]],
[[3., 3., 3., 3.],
[4., 4., 4., 4.]]]])
torch.Size([1, 2, 2, 4])
以第1维进行拼接,将第二个括号内的内容进行拼接
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=2).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.],
[4., 4., 4., 4.]]]])
torch.Size([1, 1, 4, 4])
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=3).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1., 3., 3., 3., 3.],
[2., 2., 2., 2., 4., 4., 4., 4.]]]])
torch.Size([1, 1, 2, 8])
以上为torch.cat的使用方法,两个tensor的维度必须一致,最终生成的张量的维度也没有变化,但是torch.cat就不一样了
torch.stack
依然使用上边的例子,看看这个函数的功能
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.stack((a0,a1),dim=0).type(torch.FloatTensor)
tensor([[[[[1., 1., 1., 1.],
[2., 2., 2., 2.]]]],
[[[[3., 3., 3., 3.],
[4., 4., 4., 4.]]]]])
torch.Size([2, 1, 1, 2, 4])
a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.stack((a0,a1),dim=1).type(torch.FloatTensor)
tensor([[[[[1., 1., 1., 1.],
[2., 2., 2., 2.]]],
[[[3., 3., 3., 3.],
[4., 4., 4., 4.]]]]])
torch.Size([1, 2, 1, 2, 4])
从上边两个例子,可以看出,对于torch.stack来说,会先将原始数据维度扩展一维,然后再按照维度进行拼接,具体拼接操作同torch.cat类似
贴个torch.stack()官方文档的截图
dim代表沿着哪个维度进行堆叠
举个例子:
dim=0时:(dim不写时,默认为0)
a: 2x3 ; b: 2x3 ; c: 2x2x3
dim=1时:略
dim=2时:
a: 2x3 ; b: 2x3 ; c: 2x3x2
总结
参考链接
https://zhuanlan.zhihu.com/p/70035580
https://www.pianshen.com/article/10611294719/