import torch
# Demonstrate torch.cat along each axis of two same-shaped 3-D tensors.
# NOTE: inputs come from torch.randn with no seed, so values differ per run.
c = torch.randn((2, 3, 4))
d = torch.randn((2, 3, 4))
print(c, d)
print(c.shape, d.shape)

# Concatenate once per axis and reuse the result, rather than computing each
# torch.cat twice (once to print it, once to keep it).
# All sizes except the concat dim must match; the concat-dim sizes add up.
m = torch.cat((c, d), dim=0)  # (2+2, 3, 4) -> (4, 3, 4)
n = torch.cat((c, d), dim=1)  # (2, 3+3, 4) -> (2, 6, 4)
l = torch.cat((c, d), dim=2)  # (2, 3, 4+4) -> (2, 3, 8)
print(m)
print(n)
print(l)
print(m.shape, n.shape, l.shape)
tensor([[[ 1.1613, -0.4922, 1.9275, 0.8882],
[-0.2998, 0.7337, 0.3202, 0.9770],
[ 1.4552, 0.0873, 0.7061, 1.0533]],
[[-1.2791,