Flatten
含义
flatten的中文含义为“扁平化”,具体怎么理解呢?我们可以尝试这么理解,假设你的数据为1维数据,那么这个数据天然就已经扁平化了,如果是2维数据,那么扁平化就是将2维数据变为1维数据,如果是3维数据,那么就要根据你自己所选择的“扁平化程度”来进行操作,假设需要全部扁平化,那么就直接将3维数据变为1维数据,如果只需要部分扁平化,那么有一维的数据不会进行扁平操作。
操作
torch.flatten()方法有三个参数,分别:
- input tensor:该方法的输入
- start_dim:开始flatten的维度
- end_dim:结束flatten的维度
实践
1、全部扁平化
import numpy as np
import torch
x = np.arange(16)
x = np.reshape(x, (4, 4))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x) # 默认扁平化程度为最高
print('after flatten', x)
==============result======================
before flatten: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
after flatten: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
- 部分扁平化
-
import numpy as np import torch x = np.arange(27) x = np.reshape(x, (3,3,3)) x = torch.from_numpy(x) print('before flatten', x) x = torch.flatten(x, start_dim=0, end_dim=1) # 默认扁平化程度为最高 print('after flatten', x) ==============result====================== before flatten: tensor([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8]], [[ 9, 10, 11], [12, 13, 14], [15, 16, 17]], [[18, 19, 20], [21, 22, 23], [24, 25, 26]]]) after flatten: tensor([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11], [12, 13, 14], [15, 16, 17], [18, 19, 20], [21, 22, 23], [24, 25, 26]])
import numpy as np import torch x = np.arange(27) x = np.reshape(x, (3,3,3)) x = torch.from_numpy(x) print('before flatten', x) x = torch.flatten(x, start_dim=1, end_dim=2) # 默认扁平化程度为最高 print('after flatten', x) ==============result====================== before flatten: tensor([[[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8]], [[ 9, 10, 11], [12, 13, 14], [15, 16, 17]], [[18, 19, 20], [21, 22, 23], [24, 25, 26]]]) after flatten: tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8], [ 9, 10, 11, 12, 13, 14, 15, 16, 17], [18, 19, 20, 21, 22, 23, 24, 25, 26]])