generic_filter简介
scipy原生的滤波器可能不能满足我们的要求。比如希望计算5*5邻域内,值小于中心像素值的像素的平均值,原生的滤波器就无能为力。
这时候就需要用到generic_filter
来自定义滤波函数。
import numpy as np
from scipy.ndimage import generic_filter
window_size = 5
def custom_func(window):
center_pixel = window[window_size**2//2] # 传入的窗口被展平为一维
smaller_pixels = window[window<center_pixel]
mean_value = smaller_pixels.mean()
return mean_value
example_arr = np.arange(225).reshape(15,15)
result = generic_filter(example_arr, custom_func, window_size)
在上面的例子中,我们定义了一个可以计算5*5邻域内,值小于中心像素值的像素的平均值的函数custom_func
。
调用generic_filter
时,generic_filter
将遍历example_arr
中的每个元素的5 * 5邻域,将其展平为一维、传入custom_func
,返回值作为中心元素对应的计算结果。如此,我们实现了自定义滤波器。
generic_filter
的详细参数可以查阅文档,这边不赘述。
使用C拓展加速generic_filter
generic_filter
的本质是在Python中显式地遍历输入数据的每个元素的邻域并调用custom_func
。众所周知,Python的循环非常慢,使用了这种方式的generic_filter
的效率自然也十分低下,远远不如原生的滤波器。
好在scipy提供了性能改善的方法,可以通过往generic_filter
中传入scipy.LowLevelCallable
类型作为滤波函数,代替直接在Python中定义的滤波函数。scipy.LowLevelCallable
可以有多种获得方式,包括使用numba
装饰函数、在Cython中定义、使用ctypes
引入动态链接库等。传入scipy.LowLevelCallable
时,generic_filter
不再通过循环调用滤波函数,而是先将其编译,效率得到较大改善。
官方文档没有给出使用scipy.LowLevelCallable
加速generic_filter的案例,只给了文档说明,这边给出加速流程和示例代码。
使用C语言编写滤波函数
首先编写在testlib.c中编写滤波函数代码。首先要注意的是,滤波函数的签名是规定好了的,只能是:
int callback(double *buffer, long filter_size,
double *return_value, void *user_data)
(好像还可以是其他的,详见官方文档,但是我没有尝试成功)
其中buffer
是传入的一维数组指针,对应窗口内的数据;filter_size
是在generic_filter
的参数中确定的窗口大小的平方,即窗口的元素个数;return_value
是用于返回结果的指针;user_data
是自定义的传入数据的指针(用法详见官方文档,这边没用到)。
完整的C代码如下:
int callback(double *buffer, long filter_size,
double *return_value, void *user_data)
{
int center_idx = (int) filter_size/2;
float center_ele = (float) buffer[center_idx]; //获取中心元素值
float sum = 0;
float smaller_count = 0;
for(int i = 0; i < filter_size; i++)
{
if (buffer[i] <= center_ele){
smaller_count ++;
sum += buffer[i];
}
}
return_value[0] = (double) sum/smaller_count; //return_value[0]是计算结果
return 1; //使用指针返回结果,返回值无所谓
}
编译C代码为.so或.dll文件
Linux
运行以下命令进行编译,得到testlib.so文件
gcc -shared -fPIC -o testlib.so testlib.c
Windows
用编译器编译可以得到testlib.dll文件
在Python中加载并调用编译好的函数
import os, ctypes
import numpy as np
from scipy import LowLevelCallable
from scipy.ndimage import generic_filter
# 加载.so,定义其ctypes类型
lib = ctypes.CDLL(os.path.abspath('testlib.so'))
lib.callback.restype = ctypes.c_int
lib.callback.argtypes = (ctypes.POINTER(ctypes.c_double),
ctypes.c_long,
ctypes.POINTER(ctypes.c_double),
ctypes.c_void_p)
# 将函数包装为scipy.LowLevelCallable对象
func = LowLevelCallable(lib.callback)
ii = np.arange(81).reshape(9,9).astype(np.float64)
# 将数据和scipy.LowLevelCallable传入,计算结果
res2 = generic_filter(ii, func, 3)
print(res2)