torch.cat()
举例如下:
import torch
import numpy as np
# Demo of torch.cat: join two 2x3 integer matrices along each of their two axes.
x = np.array([
    [1, 2, 3],
    [4, 5, 6],
])
y = np.array([
    [7, 8, 9],
    [10, 11, 12],
])

# Show the NumPy inputs before they are converted to tensors.
print("=================x=====================", x, sep="\n")
print("=================y=====================", y, sep="\n")

# Rebind both names to tensors built from the arrays.
x, y = torch.from_numpy(x), torch.from_numpy(y)

# dim=0 places y's rows below x's rows (result shape 4x3);
# dim=1 places y's columns to the right of x's columns (result shape 2x6).
a = torch.cat([x, y], dim=0)
b = torch.cat([x, y], dim=1)
# NOTE: a dim=2 variant is intentionally not run here; both inputs are 2-D,
# so presumably only dims 0 and 1 are valid for this concatenation.

# Print each result followed by its shape.
print("=================a=====================", a, a.shape, sep="\n")
print("=================b=====================", b, b.shape, sep="\n")
结果:
=================x=====================
[[1 2 3]
 [4 5 6]]
=================y=====================
[[ 7  8  9]
 [10 11 12]]
=================a=====================
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]], dtype=torch.int32)
torch.Size([4, 3])
=================b=====================
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]], dtype=torch.int32)
torch.Size([2, 6])
torch.stack()
torch.stack()使用