sum函数中的axis
相当于沿axis指定的那个维度求和,结果中不再保留该维度。
例如:a的形状是(2,3,4),则a.sum(axis=0)结果的形状为(3,4),
a.sum(axis=1)结果的形状为(2,4)。
即除了axis指定的那个维度外,其它维度索引值都相同的元素累加在一起。
def dfs_index(a, dim, index, outlist):
    """Append every complete index of ``a`` reachable from ``dim`` to ``outlist``.

    Dimensions ``dim .. a.ndim - 1`` are walked depth-first, trying each
    coordinate in ascending order, so the last dimension varies fastest.

    Args:
        a: Object exposing ``ndim`` and ``shape`` (e.g. a NumPy array).
        dim: Dimension currently being enumerated; pass 0 to enumerate
            every index of ``a``.
        index: Mutable scratch sequence with at least ``a.ndim`` entries
            and a ``copy()`` method. Entries from ``dim`` onward are
            overwritten in place; earlier entries are left as given and
            show up unchanged in every collected index.
        outlist: List that receives one copy of ``index`` per complete
            index.

    Returns:
        None. Results are delivered through ``outlist``.
    """
    if dim >= a.ndim:
        # Store a copy: ``index`` is reused and mutated by the callers' loops.
        outlist.append(index.copy())
        return
    for val in range(a.shape[dim]):
        index[dim] = val
        dfs_index(a, dim + 1, index, outlist)
def sum_axis(a, axis=0):
    """Sum ``a`` along one axis, imitating ``numpy.sum(a, axis=axis)``.

    Elements whose indices agree everywhere except at ``axis`` are added
    together, so the result has the shape of ``a`` with that axis removed.

    Args:
        a: Array or array-like (e.g. nested lists) with at least one
            dimension.
        axis: Axis to sum over. Negative values count from the last axis.

    Returns:
        A NumPy array of shape ``a.shape`` without the ``axis`` entry. For
        1-D input this is a 0-d array.

    Raises:
        IndexError: If ``axis`` is out of range for ``a`` (this includes
            any axis on a 0-d input).
    """
    a = np.asarray(a)
    ndim = a.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(
            "axis {} is out of bounds for array of dimension {}".format(axis, ndim)
        )
    if axis < 0:
        axis += ndim

    # Pick the accumulator dtype the way numpy.sum does: bool and integers
    # narrower than the platform integer are widened, otherwise bool input
    # would be OR-ed instead of counted and small ints would overflow.
    acc_dtype = a.dtype
    if acc_dtype.kind == "b":
        acc_dtype = np.dtype(np.int_)
    elif acc_dtype.kind in ("i", "u"):
        platform = np.dtype(np.int_ if acc_dtype.kind == "i" else np.uint)
        if acc_dtype.itemsize < platform.itemsize:
            acc_dtype = platform

    # Output shape is the input shape with the summed axis dropped.
    out_shape = a.shape[:axis] + a.shape[axis + 1:]
    acc_sum = np.zeros(out_shape, dtype=acc_dtype)

    # np.ndindex yields each index tuple lazily, last axis fastest.
    for idx in np.ndindex(*a.shape):
        # Dropping the coordinate at ``axis`` gives the target cell.
        acc_sum[idx[:axis] + idx[axis + 1:]] += a[idx]
    return acc_sum
>>> a = np.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.sum(0)
array([3, 5, 7])
>>> sum_axis(a, 0)
array([3, 5, 7])
>>> a.sum(1)
array([ 3, 12])
>>> sum_axis(a, 1)
array([ 3, 12])
# test case 2
>>> b = np.arange(24).reshape(3,2,4)
>>> b
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7]],
       [[ 8,  9, 10, 11],
        [12, 13, 14, 15]],
       [[16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> b.sum(0)
array([[24, 27, 30, 33],
[36, 39, 42, 45]])
>>> sum_axis(b, 0)
array([[24, 27, 30, 33],
[36, 39, 42, 45]])
>>> b.sum(1)
array([[ 4, 6, 8, 10],
[20, 22, 24, 26],
[36, 38, 40, 42]])
>>> sum_axis(b, 1)
array([[ 4, 6, 8, 10],
[20, 22, 24, 26],
[36, 38, 40, 42]])
>>> b.sum(2)
array([[ 6, 22],
[38, 54],
[70, 86]])
>>> sum_axis(b, 2)
array([[ 6, 22],
[38, 54],
[70, 86]])
... other test case