以编辑距离为例,记录下rust加速python的过程
先放完整代码地址
WenmuZhou/rust_python
新建一个工程
cargo new rust_python
cd rust_python
编辑 Cargo.toml文件
在文件中加上
[lib]
name =
"edit_distence_rust"
# 最终生存的so文件命名为 "lib{name}.so"
crate-type =
["dylib"]
[dependencies.cpython]
version =
"*"
features =
["extension-module"]
编辑src/lib.rs文件
在lib.rs文件内写下如下内容
#[macro_use]
extern crate cpython;
use cpython::{PyResult,
Python};
use std::mem;
// 生成可以直接在python中import的so
// 此处的module_name 和 initmodule_name可以随意命名
// PyInit_edit_distence_rust1 这里的edit_distence_rust1要和使用时的so的文件名一致
py_module_initializer!(module_name, initmodule_name,
PyInit_edit_distence_rust,
|py, m|
{
m.add(py,
"__doc__",
"Module documentation string")?;
m.add(py,
"edit_distance", py_fn!(py, edit_distance_py(a:
&str, b:
&str)))?;
Ok(())
});
fn edit_distance_py(_:
Python, a:
&str, b:
&str)
->
PyResult<i32>
{
let result = edit_distance(a,b)
as i32;
Ok(result)
}
pub
fn edit_distance(a:
&str, b:
&str)
-> usize {
let
mut a = a;
let
mut b = b;
let
mut len_a = a.chars().count();
let
mut len_b = b.chars().count();
if len_a < len_b{
mem::swap(&mut a,
&mut b);
mem::swap(&mut len_a,
&mut len_b);
}
// handle special case of 0 length
if len_a ==
0
{
return len_b
}
else
if len_b ==
0
{
return len_a
}
let len_b = len_b +
1;
let
mut pre;
let
mut tmp;
let
mut cur = vec![0; len_b];
// initialize string b
for i in 1..len_b {
cur[i]
= i;
}
// calculate edit distance
for
(i,ca) in a.chars().enumerate()
{
// get first column for this row
pre = cur[0];
cur[0]
= i +
1;
for
(j, cb) in b.chars().enumerate()
{
tmp = cur[j +
1];
cur[j +
1]
= std::cmp::min(
// deletion
tmp +
1, std::cmp::min(
// insertion
cur[j]
+
1,
// match or substitution
pre +
if ca == cb {
0
}
else
{
1
}));
pre = tmp;
}
}
cur[len_b -
1]
}
// 生成单独的so文件,python中使用ctypes调用
use std::os::raw::c_char;
use std::ffi::CStr;
#[no_mangle]
pub
extern
"C"
fn edit_distance_so(a:
*const c_char, b:
*const c_char)
-> usize {
let a =
unsafe
{
CStr::from_ptr(a)
};
let
mut a:
&str =a.to_str().unwrap();
//
let b =
unsafe
{
CStr::from_ptr(b)
};
let
mut b:
&str = b.to_str().unwrap();
//
let
mut len_a = a.chars().count();
let
mut len_b = b.chars().count();
if len_a < len_b{
mem::swap(&mut a,
&mut b);
mem::swap(&mut len_a,
&mut len_b);
}
// handle special case of 0 length
if len_a ==
0
{
return len_b
}
else
if len_b ==
0
{
return len_a
}
let len_b = len_b +
1;
let
mut pre;
let
mut tmp;
let
mut cur = vec![0; len_b];
// initialize string b
for i in 1..len_b {
cur[i]
= i;
}
// calculate edit distance
for
(i,ca) in a.chars().enumerate()
{
// get first column for this row
pre = cur[0];
cur[0]
= i +
1;
for
(j, cb) in b.chars().enumerate()
{
tmp = cur[j +
1];
cur[j +
1]
= std::cmp::min(
// deletion
tmp +
1, std::cmp::min(
// insertion
cur[j]
+
1,
// match or substitution
pre +
if ca == cb {
0
}
else
{
1
}));
pre = tmp;
}
}
cur[len_b -
1]
}
生成python包
cargo build —release
会看到如下输出
拷贝生成的so文件到项目目录下
cp target/release/libedit_distence_rust.so edit_distence_rust.so
```bash
src/lib.rs 里的PyInit_edit_distence_rust中的edit_distence_rust要和这里cp的目标文件名一致,不然会报如下错误(手动将edit_distence_rust.so改为hello.so,然后在python里执行import hello)
编辑test.py
这里使用python的编辑距离包Levenshtein进行结果和速度的对比
```python
import Levenshtein
import time
import edit_distence_rust
print(dir(edit_distence_rust))
tic = time.time()
for i in range(270000):
dis = edit_distence_rust.edit_distance('我的中国心1',
'别人也是调用的底层C文件吧')
print('我的 rust cython so:', time.time()-tic, dis)
import ctypes
so = ctypes.CDLL('edit_distence_rust.so')
tic = time.time()
for i in range(270000):
dis = so.edit_distance_so('我的中国心'.encode(
'utf8'),
'别人也是调用的底层C文件吧'.encode('utf8'))
print('我的 rust ctypes so:', time.time()-tic, dis)
import Levenshtein
tic = time.time()
for i in range(270000):
dis =
Levenshtein.distance('我的中国心1',
'别人也是调用的底层C文件吧')
print('别人的库', time.time()-tic, dis)
class Solution:
def minDistance(self, word1, word2):
l1 = len(word1)
+
1
l2 = len(word2)
+
1
if l2 > l1:
return self.minDistance(word2, word1)
m =
[0]*l2 # 遍历到底i行时m[i]表示s1[:i-1]替换为s2[:j-1]的编辑距离
for i in range(1, l2):
m[i]
= i
p =
0
# 用于存储上一行左上角的值
for i in range(1, l1):
p = m[0]
m[0]
= i
for j in range(1, l2):
tmp = m[j]
# 先将上一行i处的结果存起来
m[j]
= p if word1[i-1]
== word2[j -
1]
else min(m[j-1]
+
1, m[j]
+
1, p +
1)
p = tmp
return m[l2-1]
s =
Solution()
d =
0
tic = time.time()
for i in range(270000):
d = s.minDistance('我的中国心',
'别人也是调用的底层C文件吧')
print('自己的py实现', time.time()-tic, d)
相比纯python实现,可以取得43倍的加速,cython形式的so也比ctypes调用的快一些。
完结。
本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,转载请附上原文出处链接和本声明。
本文链接地址: