import torch
if __name__ == '__main__':
a = torch.tensor([ [ [1,2,3,2,1],
[2,5,6,3,6] ],
[ [2,1,5,9,8],
[4,6,8,1,1] ] ],dtype=torch.float)
a = torch.reshape(a,(-1,5))
print(a)
tensor([[1., 2., 3., 2., 1.],
[2., 5., 6., 3., 6.],
[2., 1., 5., 9., 8.],
[4., 6., 8., 1., 1.]])