一、torch.stack
import torch  # import here so this snippet runs on its own

# Two 2-D tensors, each of shape (1, 3).
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5, 6]])

# torch.stack inserts a NEW dimension at `dim` (rank grows from 2 to 3);
# every input must have the same shape.
print(torch.stack((a, b), dim=0))  # shape (2, 1, 3)
print(torch.stack((a, b), dim=1))  # shape (1, 2, 3)
#输出:
tensor([[[1, 2, 3]],

        [[4, 5, 6]]])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
二、torch.cat
import torch  # import here so this snippet runs on its own

# Two 2-D tensors, each of shape (1, 3).
a = torch.tensor([[1, 2, 3]])
b = torch.tensor([[4, 5, 6]])

# torch.cat joins along an EXISTING dimension, so the rank stays 2;
# only the size of `dim` grows.
print(torch.cat((a, b), dim=0))  # shape (2, 3)
print(torch.cat((a, b), dim=1))  # shape (1, 6)
#输出:
tensor([[1, 2, 3],
        [4, 5, 6]])
tensor([[1, 2, 3, 4, 5, 6]])
三、torch.chunk
import torch

# A 3x4 grid holding the values 0..11 in row-major order.
a = torch.arange(12).view(3, 4)
print(a)

# Split into 2 pieces along the rows (dim=0), then along the columns (dim=1).
# As the printed result shows, 3 rows split as 2 + 1 and 4 columns as 2 + 2.
print(torch.chunk(a, 2, dim=0))
print(torch.chunk(a, 2, dim=1))
#输出:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
(tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]]), tensor([[ 8,  9, 10, 11]]))
(tensor([[0, 1],
        [4, 5],
        [8, 9]]), tensor([[ 2,  3],
        [ 6,  7],
        [10, 11]]))