函数原型:
numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)
这个函数的各个参数的含义请大家自行查阅了解,这里只记录一下参数axis
和*args
的作用。
参数axis
首先定义一个函数来打印数组的shape
与内容,然后生成一个4x3的数组array
:
def print_a(array):
print("shape: %s, array is \n%s"%(array.shape, array))
array=np.array([[2,3,4],[3,4,5],[6,7,8],[7,8,9]])
对这个数组调用np.apply_along_axis
函数,看该函数会对数组产生什么作用
tmp=np.apply_along_axis(print_a, 0, array) # 在0维上调用
结果如下:
shape: (4,), array is
[2 3 6 7]
shape: (4,), array is
[3 4 7 8]
shape: (4,), array is
[4 5 8 9]
可以看到,数组array
被切成了3部分,每一部分是array
在维度0上的元素,也就是说,apply_along_axis
的功能是将数组在指定维度上的元素聚合起来,聚合的结果是一个向量,array
的维度是4x3
,如果将第0维聚合起来的话,每一个向量的长度为4,如果将第1维聚合起来,那么每一个向量的维度是3;
下面是在维度1上调用该函数的结果
tmp=np.apply_along_axis(print_a, 1, array)
结果如下:
shape: (3,), array is
[2 3 4]
shape: (3,), array is
[3 4 5]
shape: (3,), array is
[6 7 8]
shape: (3,), array is
[7 8 9]
参数args
args
指这个函数可以接受很多其他参数,这些参数要传入函数func1d
中进行使用,需要注意的是,除了第3个参数arr
之外,后面传入的参数不会被按照维度切分,就是说后面传入的数组会原封不动的传入到函数func1d
中。
示例代码:
def print_b(arr, brr):
print(arr.shape, brr.shape)
arr=np.array([[2,3],[3,4]]) # (2,2)
brr=np.array([[2,3,4],[3,4,5],[6,7,8],[7,8,9]]) # (4,3)
# 调用apply_along_axis函数
tmp=np.apply_along_axis(print_b,0,arr,brr)
结果如下:
tmp=np.apply_along_axis(print_b,0,arr,brr)
(2,) (4, 3)
(2,) (4, 3)
可以看到,brr的维度没有变化。