在numpy的许多函数中,会出现'keepdims'参数,以numpy.sum()为例:
官方文档中给出的解释:
numpy.sum(a, axis=None, dtype=None, out=None, keepdims=<no value>, initial=<no value>, where=<no value>)
'''
keepdimsbool, optional
If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.
If the default value is passed, then keepdims will not be passed through to the sum method of sub-classes of ndarray, however any non-default value will be. If the sub-class’ method does not implement keepdims any exceptions will be raised.
'''
看的一脸懵,还是跑个代码来得实在:
a = np.array([[0, 0, 0],
[0, 1, 0],
[0, 2, 0],
[1, 0, 0],
[1, 1, 0]])
print(a)
'''
输出:
[[0 0 0]
[0 1 0]
[0 2 0]
[1 0 0]
[1 1 0]]
'''
a_sum_true = np.sum(a, keepdims=True)
print(a_sum_true)
print(a_sum_true.shape)
a_sum_false = np.sum(a, keepdims=False)
print(a_sum_false)
print(a_sum_false.shape)
'''
输出:
[[6]]
(1, 1)
6
()
'''
a_sum_axis1_true = np.sum(a, axis=1, keepdims=True)
print(a_sum_axis1_true)
print(a_sum_axis1_true.shape)
a_sum_axis1_false = np.sum(a, axis=1, keepdims=False)
print(a_sum_axis1_false)
print(a_sum_axis1_false.shape)
'''
输出:
[[0]
[1]
[2]
[1]
[2]]
(5, 1)
[0 1 2 1 2]
(5,)
'''
a_sum_axis0_true = np.sum(a, axis=0, keepdims=True)
print(a_sum_axis0_true)
print(a_sum_axis0_true.shape)
a_sum_axis0_false = np.sum(a, axis=0, keepdims=False)
print(a_sum_axis0_false)
print(a_sum_axis0_false.shape)
'''
输出:
[[2 4 0]]
(1, 3)
[2 4 0]
(3,)
'''
如果并不指定'axis'参数,输出的结果是相同的,区别在于当' keepdims = True
'时,输出的是2D结果。
如果指定'axis'参数,输出的结果也是相同的,区别在于'keepdims = True
'时,输出的是2D结果。
可以理解为'keepdims = True
'参数是为了保持结果的维度与原始array相同。
参考: