一.torch.cat()函数解析
函数功能
函数将多个张量(tensor)按指定维度拼接在一起,在卷积中通常把相同2D维度,不同通道数的的张量进行拼接(Concatenation),则可使用torch.cat(),不会增加新的维度。
四阶张量有4个维度(a,b,c,d),其对应维度标号如下:
a->dim=0;
b->dim=1;
c->dim=2;
d->dim=4;
按dim=1进行拼接,则按通道维度进行组合
代码操作
输入:
import torch
a = torch.randn(1,2,4,4)
b = torch.randn(1,2,4,4)
c = torch.cat((a,b),dim=1)
print(c.shape)
print(c)
输出:
torch.Size([1, 4, 4, 4])
tensor([[[[-1.2265, 0.2191, -0.2701, -1.1974],
[-0.2012, 0.2933, -0.7879, -0.8292],
[ 0.7499, 0.9272, -0.2942, -0.7157],
[ 1.2976, 0.3106, -1.3818, -0.4638]],
[[ 0.8694, -0.5793, 1.6817, -0.5349],
[ 0.3290, -0.1985, -0.7746, -1.4872],
[ 0.7211, 1.0945, 0.5810, -0.3578],
[-2.0604, 0.9210, -0.2202, 2.3572]],
[[ 1.1119, -1.4924, -0.8243, -0.6296],
[ 0.6080, 0.2288, 0.2951, 0.0872],
[-0.6361, 0.8789, -0.5911, -0.1276],
[-0.6944, -0.0813, 1.2276, 2.0860]],
[[-0.9725, 0.4640, -0.3887, -0.3565],
[-0.4997, 1.3356, -0.2833, -1.3948],
[-0.1654, 0.0709, 1.6289, 0.0100],
[-0.2617, 0.1561, 1.2905, -0.8922]]]])
进程已结束,退出代码0