简而言之就是
拼接的是哪个维度 哪个维度就增加!
import torch
z=torch.randn(8,4,32,32)
z = torch.cat([z, z], 0)
print(z.shape)
torch.Size([16, 4, 32, 32])
import torch
z=torch.randn(8,4,32,32)
z = torch.cat([z, z], 2)
print(z.shape)
torch.Size([8, 4, 64, 32])
简而言之就是
拼接的是哪个维度 哪个维度就增加!
import torch
z=torch.randn(8,4,32,32)
z = torch.cat([z, z], 0)
print(z.shape)
torch.Size([16, 4, 32, 32])
import torch
z=torch.randn(8,4,32,32)
z = torch.cat([z, z], 2)
print(z.shape)
torch.Size([8, 4, 64, 32])