目录
pytorch 2维度切分和拼接
第二个参数是切块大小:
import torch
if __name__ == '__main__':
data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
y = torch.split(data, 1, dim=1) # 按照4这个维度去分,每大块包含2个小块
for index, i in enumerate(y):
print('index',index,i.size())
print(i)
pytorch拼接:
c1 = torch.cat((y[0],y[1],y[2]), dim=0)
print(c1)
c3 = torch.stack((y[0],y[1],y[2]), dim=0)
print(c1)