torch.stack(), torch.cat()用法详解
if __name__ == '__main__':
import torch
x_dat = torch.tensor([[1, 2], [3,4], [5,6]], dtype=torch.float)
y_dat = torch.tensor([[10, 20], [30,40], [50,60]], dtype=torch.float)
res=torch.stack((x_dat,y_dat),0)
print(res)
res = torch.stack((x_dat, y_dat), 1)
print(res)
res = torch.stack((x_dat, y_dat),2)
print(res)
res = torch.cat((x_dat, y_dat), 0)
print(res)
res = torch.cat((x_dat, y_dat), 1)
print(res)
res = torch.cat((x_dat, y_dat), 2)
print(res)
stack 是合并,但是内容单元不变。
cat是追加,内容尺寸会变化