cat是concatnate的缩写,concatnate的含义拼接,深度学习模型中最为常见的是通道拼接。
一、普通用法
- dim = 1:表示将张量A、B按照维数1进行拼接,换句话说,就是按照列进行拼接
torch.cat((A,B),dim = 1)
- 案例1
x = torch.randn(3, 4)
y = torch.randn(3, 2)
print(x, x.size())
print(y, y.size())
z = torch.cat((x, y), dim = )
print(z, z.size())
- dim = 0:表示将张量A、B按照维数0进行拼接,换句话说,就是按照行进行拼接
torch.cat((A,B),dim = 0)
- 案例2
x = torch.randn(2,3)
y = torch.randn(5,3)
print(x)
print(y)
z = torch.cat((x,y),dim = 0)
print(z)
二、进阶用法
除上述普通用法外,torch.cat()也可以将一个列表中的tensor拼接起来。
lst = []
x = torch.randn(3,4)
y = torch.randn(2,4)
print(x, x.size())
print(y, y.size())
lst.append(x)
lst.append(y)
z = torch.cat(lst, dim = 0)
print(z, z.size())
之后我会尽量每天都会更新一篇PyTorch的小知识点,不积硅步,无以至千里,只要每天积累一点点,一定会有提升的!希望这篇文章对大家有帮助!