- flatten()是对多维数据的降维函数。
- flatten(),默认缺省参数为0,也就是说flatten()和flatte(0)效果一样。
- python里的flatten(dim)表示,从第dim个维度开始展开,将后面的维度转化为一维.也就是说,只保留dim之前的维度,其他维度的数据全都挤在dim这一维。
比如我们随机定义一个维度为(2,3,4)的数据a
import torch
a = torch.rand(2,3,4)
a输出结果为:
a此时的维度为(2,3,4)
import torch
a = torch.rand(2,3,4)
a输出结果为:
a此时的维度为(2,3,4)