语法与定义
numpy.squeeze(a,axis = None),其中:
- a:输入的数组;
- axis:需要删除的维度,但是指定的维度必须为单维度,否则将会报错(xis的取值可为None 或 int 或 tuple of ints, 可选。若axis为空,则删除所有单维度的条目);
- 会将运行结果已数组形式返回,并且不改变原数组。
深度学习中的常见用途
可以使用该函数将表示向量的数组转换为秩为1的数组,这时再利用matplotlib库函数画图就可以正常的显示结果了。
举例
例1
import numpy as np
a = np.arange(10).reshape(1,10)
a
array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
(1, 10)
b = np.squeeze(a)
b
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b.shape
(10,)
例2
d = np.arange(10).reshape(1,2,5)
d
array([[[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]]])
d.shape
(1, 2, 5)
np.squeeze(d)
array([[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]])
np.squeeze(d).shape
(2, 5)
np.squeeze(d,axis=1)
ValueError: cannot select an axis to squeeze out which has size not equal to one
由以上结果可知,当指定维度值不为1时,报错。
例3
e = np.arange(10).reshape(1,10,1)
e
np.squeeze(e)
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
np.squeeze(e).shape
(10,)
由例3可知,axis为默认值时,np.squeeze()函数会将所有单维度删除生成新的数组。
参考
https://blog.csdn.net/zenghaitao0128/article/details/78512715