"""Scratch demo: how ``Tensor.reshape`` changes a tensor's dimensions.

``reshape`` keeps the elements in the same flattened (row-major) order and
only changes how that sequence is grouped into dimensions. The total number
of elements (here 2 * 3 * 4 = 24) must stay the same for every target shape.
"""
import torch

# A (2, 3, 4) tensor holding 1..24 in row-major order, cast to float32.
x = torch.tensor([
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
]).float()
print(x.shape)  # torch.Size([2, 3, 4])

# Split the last axis (4) into (2, 2): 2 * 3 * 2 * 2 == 24 elements.
x = x.reshape(2, 3, 2, 2)  # [2, 3, 2, 2]

# Regroup the same 24 elements as (2, 2, 6); element order is unchanged,
# so x[0, 0] is [1, 2, 3, 4, 5, 6].
x = x.reshape(2, 2, 6)  # [2, 2, 6]
print(x.shape)  # torch.Size([2, 2, 6])