cython初探

我一直非常喜欢 Python。当人们提到 Python 的时候,经常会说到下面两个优点:

  1. 写起来方便
  2. 容易调用 C/C++ 的库

然而实际上,第一点是以巨慢的执行速度为代价的,而第二点也需要库本身按照 Python 的规范使用 Python API、导出相应的符号。

天壤实习的时候,跟 Cython 打了不少交道,觉得这个工具虽然 Bug 多多,写的时候也有些用户体验不好的地方,但已经能极大提高速度和方便调用 C/C++,还是非常不错的。这里就给大家简单介绍一下 Cython(注意区别于 CPython)。Cython 可以让我们方便地:

  • 用 Python 的语法混合编写 Python 和 C/C++ 代码,提升 Python 速度
  • 调用 C/C++ 代码

例子:矩阵乘法

假设我们现在正在编写一个很简单的矩阵乘法代码,其中矩阵是保存在 numpy.ndarray 中。Python 代码可以这么写:

# dot_python.py
import numpy as np

def naive_dot(a, b):
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

不用猜也知道这比起 C/C++ 写的要慢的不少。我们感兴趣的是,怎么用 Cython 加速这个程序。我们先上 Cython 程序代码:

# dot_cython.pyx
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.float32_t, ndim=2] _naive_dot(np.ndarray[np.float32_t, ndim=2] a, np.ndarray[np.float32_t, ndim=2] b):
    cdef np.ndarray[np.float32_t, ndim=2] c
    cdef int n, p, m
    cdef np.float32_t s
    if a.shape[1] != b.shape[0]:
        raise ValueError('shape not matched')
    n, p, m = a.shape[0], a.shape[1], b.shape[1]
    c = np.zeros((n, m), dtype=np.float32)
    for i in xrange(n):
        for j in xrange(m):
            s = 0
            for k in xrange(p):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c

def naive_dot(a, b):
    return _naive_dot(a, b)

可以看到这个程序和 Python 写的几乎差不多。我们来看看不一样部分:

  • Cython 程序的扩展名是 .pyx
  • cimport 是 Cython 中用来引入 .pxd 文件的命令。有关 .pxd 文件,可以简单理解成 C/C++ 中用来写声明的头文件,更具体的我会在后面写到。这里引入的两个是 Cython 预置的。
  • @cython.boundscheck(False) 和 @cython.wraparound(False) 两个修饰符用来关闭 Cython 的边界检查
  • Cython 的函数使用 cdef 定义,并且他可以给所有参数以及返回值指定类型。比方说,我们可以这么编写整数 min 函数:

      cdef int my_min(int x, int y):
          return x if x <= y else y
    

    这里 np.ndarray[np.float32_t, ndim=2] 就是一个类型名就像 int 一样,只是它比较长而且信息量比较大而已。它的意思是,这是个类型为 np.float32_t 的2维 np.ndarray。

  • 在函数体内部,我们一样可以使用 cdef typename varname 这样的语法来声明变量
  • 在 Python 程序中,是看不到 cdef 的函数的,所以我们这里 def naive_dot(a, b) 来调用 cdef 过的 _naive_dot 函数。
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值