@torch多维向量的cat
a0 = torch.Tensor([[[[1,1,1],[2,2,2]]]])
a1 = torch.Tensor([[[[3,3,3],[4,4,4]]]])
x = torch.cat((a0,a1),3).type(torch.FloatTensor)
#拼接对象是数据a0和数据a1,维度是三维的。
print(x)
注意要用type(torch.FloatTensor)。否则报错
@torch多维向量的cat
a0 = torch.Tensor([[[[1,1,1],[2,2,2]]]])
a1 = torch.Tensor([[[[3,3,3],[4,4,4]]]])
x = torch.cat((a0,a1),3).type(torch.FloatTensor)
#拼接对象是数据a0和数据a1,维度是三维的。
print(x)
注意要用type(torch.FloatTensor)。否则报错