flatten(a, 1)代码
import torch
a = torch.arange(120).reshape(2,3,4,5)
print(a.shape)
a = torch.flatten(a, 1)
print(a.shape)
结果:
torch.Size([2, 3, 4, 5])
torch.Size([2, 60])
flatten(a, 2)代码
import torch
a = torch.arange(120).reshape(2,3,4,5)
print(a.shape)
a = torch.flatten(a, 2)
print(a.shape)
结果
torch.Size([2, 3, 4, 5])
torch.Size([2, 3, 20])
- python里的flatten(a,dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
- 比如flatten(a,1),就是只保留第0维,第1和第2维的数据全部被总和在一维