flatten用法
简洁版:
都是降维展开操作。不同在于:
numpy
: n d a r r a y . f l a t t e n ( ) ndarray.flatten() ndarray.flatten():
只能返回一维数组。参数是字符串,可设置按行或按列展开。torch
: t o r c h . f l a t t e n ( i n p u t , s t a r t d i m = 0 , e n d d i m = − 1 ) → T e n s o r torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor torch.flatten(input,startdim=0,enddim=−1)→Tensor
可以返回高维张量,可以指定开始、结束展开的维度,即部分展开。参数是维度。
详细版:
叙述了numpy
和 torch
中常见的关于 高维数组 展开操作的一些函数。
numpy
flatten()
:展成一维,产生数据副本ravel()
:展成一维,不产生副本(修改会改变原数据)squeeze()
:压缩维数为 1 的维度
torch
flatten()
:展成低维张量(可以不是一维),产生数据副本squeeze()
:降维, 压缩维数为 1 的维度unsqueeze()
:增加维数为 1 的维度
1. numpy 展平操作
1.1 ndarray.flatten()
- 适用于
numpy
对象
flatten是numpy.ndarray.flatten
的一个函数,即返回一个折叠成一维的数组。
Parameters:
ndarray.flatten(order=‘C’) Return a copy of the array collapsed into one dimension. order : {‘C’, ‘F’, ‘A’, ‘K’}, optional:
- ‘C’ means to flatten in row-major (C-style) order.
- ‘F’ means to flatten in column-major (Fortran- style) order.
- ‘A’ means to flatten in column-major order if a is Fortran contiguous in memory, row-major order otherwise.
- ‘K’ means to flatten a in the order the elements occur in memory.
The default is ‘C’.
- 举例:
参数是 str , ‘C’: 按行展开, ’F‘: 按列展开
>>> import numpy as np
>>> a = np.array([[1,2,3],[4,5,6]])
>>> a
array([[1, 2, 3],
[4, 5, 6]])
>>> a.flatten('C')
array([1, 2, 3, 4, 5, 6])
>>> a.flatten('F')
array([1, 4, 2, 5, 3, 6])
python 中 list 展开:
对list
,使用列表表达式。
参考资料:Python中flatten( ),matrix.A用法说明
1.2 ndarray.ravel()
n d a r r a y . f l a t t e n ( ) ndarray.flatten() ndarray.flatten() 和 n d a r r a y . r a v e l ( ) ndarray.ravel() ndarray.ravel() 都是对向量的展平操作,区别在于:
- n d a r r a y . f l a t t e n ( ) ndarray.flatten() ndarray.flatten() : 返回原数组副本
- n d a r r a y . r a v e l ( ) ndarray.ravel() ndarray.ravel(): 不返回原数组副本
>>> a.flatten()[0] = -1
>>> a
array([[1, 2, 3],
[4, 5, 6]])
>>> a.ravel()[0] = -1
>>> a
array([[-1, 2, 3],
[ 4, 5, 6]])
1.3 ndarray.squeeze()
n d a r r a y . s q u e e z e ( d i m = a ) ndarray.squeeze(dim = a) ndarray.squeeze(dim=a): 对维数为1的维数降维,, 可以指定维数
>>> b = np.array([[[1],[2]]])
>>> b
array([[[1],
[2]]])
>>> b.shape
(1, 2, 1)
>>> b.squeeze()
array([1, 2])
>>> b.squeeze(0)
array([[1],
[2]])
>>> b.squeeze(2)
array([[1, 2]])
2. torch 展平操作
2.1 torch.flatten()
t o r c h . f l a t t e n ( i n p u t , s t a r t d i m = 0 , e n d d i m = − 1 ) → T e n s o r torch.flatten(input, start_dim=0, end_dim=- 1) → Tensor torch.flatten(input,startdim=0,enddim=−1)→Tensor
可以指定开始、结束展开的维度。
如从 dim = 1
开始, shape: (2, 2, 2) -> (2, 4)
>>> t = torch.tensor([[[1, 2],
... [3, 4]],
... [[5, 6],
... [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
官方文档:TORCH.FLATTEN
2.2 torch.squeeze() 和 torch.unsqueeze()
- torch 中也提供
squeeze
函数, 同 numpy
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x) # 压缩维数为 1 的维度
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0) # 压缩维度 0, 0 维不为 1
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1) # 压缩维度 1, 1 维为 1
>>> y.size()
torch.Size([2, 2, 1, 2])
- 此外, torch 中还提供了
torch.unsqueeze
,扩充维度
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])