如您所知,黄线意味着与python的一些交互发生,即使用python功能而不是原始的c功能,您可以查看生成的代码,看看会发生什么,以及是否可以/应该修复/避免。在
并不是每次与python的交互都意味着(可测量的)减速。在
我们来看看这个简化的函数:%%cython
cimport numpy as np
def use_slices(np.ndarray[np.double_t] a):
a[0:len(a)]=0.0
当我们查看生成的代码时,我们看到(我只保留了重要部分):
^{pr2}$
所以基本上我们得到一个新的片段(它是一个numpy数组),然后使用numpy的功能(PyObject_SetItem)将所有元素设置为0.0,这是隐藏在引擎盖下的C代码。在
让我们看看手写for loop的版本:cimport numpy as np
def use_for(np.ndarray[np.double_t] a):
cdef int i
for i in range(len(a)):
a[i]=0.0
它仍然使用PyObject_Length(因为length)和绑定检查,但其他情况下它是C代码。当我们比较时间时:>>> import numpy as np
>>> a=np.ones((500,))
>>> %timeit use_slices(a)
100000 loops, best of 3: 1.85 µs per loop
>>> %timeit use_for(a)
1000000 loops, best of 3: 1.42 µs per loop
>>> b=np.ones((250000,))
>>> %timeit use_slices(b)
10000 loops, best of 3: 104 µs per loop
>>> %timeit use_for(b)
1000 loops, best of 3: 364 µs per loop
您可以看到为小尺寸创建切片的额外开销,但是for版本中的额外检查意味着从长远来看它会有更多的开销。在
禁用这些检查:%%cython
cimport cython
cimport numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def use_for_no_checks(np.ndarray[np.double_t] a):
cdef int i
for i in range(len(a)):
a[i]=0.0
在生成的html中,我们可以看到a[i]变得非常简单:__pyx_t_3 = __pyx_v_i;
*__Pyx_BufPtrStrided1d(__pyx_t_5numpy_double_t *, __pyx_pybuffernd_a.rcbuffer->pybuffer.buf, __pyx_t_3, __pyx_pybuffernd_a.diminfo[0].strides) = 0.0;
}
__Pyx_BufPtrStrided1d(type, buf, i0, s0)是为(type)((char*)buf + i0 * s0)定义的。
现在:>>> %timeit use_for_no_checks(a)
1000000 loops, best of 3: 1.17 µs per loop
>>> %timeit use_for_no_checks(b)
1000 loops, best of 3: 246 µs per loop
我们可以通过在for循环中释放gil来进一步改进它:%%cython
cimport cython
cimport numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
def use_for_no_checks_no_gil(np.ndarray[np.double_t] a):
cdef int i
cdef int n=len(a)
with nogil:
for i in range(n):
a[i]=0.0
现在:>>> %timeit use_for_no_checks_no_gil(a)
1000000 loops, best of 3: 1.07 µs per loop
>>> %timeit use_for_no_checks_no_gil(b)
10000 loops, best of 3: 166 µs per loop
所以它有点快,但对于更大的阵列,您仍然无法击败numpy。在
在我看来,有两件事可以借鉴:Cython不会通过for循环将切片转换为访问,因此必须使用Python功能。在
它的开销很小,但它只是调用numpy功能,大部分工作都是在numpy代码中完成的,这不能通过Cython来加速。在
最后一次尝试使用memset函数:%%cython
from libc.string cimport memset
cimport numpy as np
def use_memset(np.ndarray[np.double_t] a):
memset(&a[0], 0, len(a)*sizeof(np.double_t))
我们得到:>>> %timeit use_memset(a)
1000000 loops, best of 3: 821 ns per loop
>>> %timeit use_memset(b)
10000 loops, best of 3: 102 µs per loop
对于大型数组,它的速度也与numpy代码一样快。在
正如DavidW建议的那样,可以尝试使用内存视图:%%cython
cimport numpy as np
def use_slices_memview(double[::1] a):
a[0:len(a)]=0.0
对于小数组,会产生稍微快一点的代码,但对于大数组,则会产生类似的快速代码(与numpy slices相比):>>> %timeit use_slices_memview(a)
1000000 loops, best of 3: 1.52 µs per loop
>>> %timeit use_slices_memview(b)
10000 loops, best of 3: 105 µs per loop
这意味着,内存视图切片的开销比numpy切片小。以下是生成的代码:__pyx_t_1 = __Pyx_MemoryView_Len(__pyx_v_a);
__pyx_t_2.data = __pyx_v_a.data;
__pyx_t_2.memview = __pyx_v_a.memview;
__PYX_INC_MEMVIEW(&__pyx_t_2, 0);
__pyx_t_3 = -1;
if (unlikely(__pyx_memoryview_slice_memviewslice(
&__pyx_t_2,
__pyx_v_a.shape[0], __pyx_v_a.strides[0], __pyx_v_a.suboffsets[0],
0,
0,
&__pyx_t_3,
0,
__pyx_t_1,
0,
1,
1,
0,
1) < 0))
{
__PYX_ERR(0, 27, __pyx_L1_error)
}
{
double __pyx_temp_scalar = 0.0;
{
Py_ssize_t __pyx_temp_extent = __pyx_t_2.shape[0];
Py_ssize_t __pyx_temp_idx;
double *__pyx_temp_pointer = (double *) __pyx_t_2.data;
for (__pyx_temp_idx = 0; __pyx_temp_idx < __pyx_temp_extent; __pyx_temp_idx++) {
*((double *) __pyx_temp_pointer) = __pyx_temp_scalar;
__pyx_temp_pointer += 1;
}
}
}
__PYX_XDEC_MEMVIEW(&__pyx_t_2, 1);
__pyx_t_2.memview = NULL;
__pyx_t_2.data = NULL;
我认为最重要的部分是:这段代码不创建额外的临时对象——它重用切片的现有内存视图。在
如果使用内存视图,我的编译器生成(至少对于我的机器)稍微快一点的代码。不知道是否值得调查。乍一看,每个迭代步骤的区别是:# created code for memview-slices:
*((double *) __pyx_temp_pointer) = __pyx_temp_scalar;
__pyx_temp_pointer += 1;
#created code for memview-for-loop:
__pyx_v_i = __pyx_t_3;
__pyx_t_4 = __pyx_v_i;
*((double *) ( /* dim=0 */ ((char *) (((double *) data) + __pyx_t_4)) )) = 0.0;
我希望不同的编译器能以不同的方式处理这些代码。但显然,第一个版本更容易优化。在
正如Behzad Jamali所指出的,double[:] a和{}之间有区别。使用切片的第二个版本在我的机器上大约快了20%。区别在于,在编译期间,double[::1]版本的内存访问是连续的,这可以用于优化。在带有double[:]的版本中,我们直到运行时才知道任何有关stride的信息。在