基本使用
# 拼接a,b,c的维度1
torch.cat([a,b,c],dim=1)
举例
#假设 abc均为3*2*2
a:tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]]])
b:tensor([[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]]])
c:tensor([[[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.]]])
print(torch.cat([a,b,c],dim=0)
#结果 :9*2*2
tensor([[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[0., 0.],
[0., 0.]],
[[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.]],
[[2., 2.],
[2., 2.]]])
print(torch.cat([a,b,c],dim=1)
#结果 :3*6*2
tensor([[[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.],
[2., 2.],
[2., 2.]],
[[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.],
[2., 2.],
[2., 2.]],
[[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.],
[2., 2.],
[2., 2.]]])
print(torch.cat([a,b,c],dim=2)
#结果 :3*2*6
tensor([[[1., 1., 0., 0., 2., 2.],
[1., 1., 0., 0., 2., 2.]],
[[1., 1., 0., 0., 2., 2.],
[1., 1., 0., 0., 2., 2.]],
[[1., 1., 0., 0., 2., 2.],
[1., 1., 0., 0., 2., 2.]]])
应用
- GoogLeNet中Inception结构将四个并行的branch结果的通道合并:
第一个维度为batch,第二个为channel,后面为宽和高
br1:Tensor:(64,64,28,28)
br2:Tensor:(64,128,28,28)
br3:Tensor:(64,32,28,28)
br4:Tensor:(64,32,28,28)
- 将四个branch结果通道合并:(只需合并第一维度channel)
torch.cat([br1, br2, br3, br4], 1)