项目使用encode_使用rust加速python

以编辑距离为例,记录下rust加速python的过程
先放完整代码地址
WenmuZhou/rust_python
新建一个工程

  1. cargo new rust_python
  2. cd rust_python


编辑 Cargo.toml文件
在文件中加上

  1. [lib]
  2. name = "edit_distence_rust" # 最终生存的so文件命名为 "lib{name}.so"
  3. crate-type = ["dylib"]
  4. [dependencies.cpython]
  5. version = "*"
  6. features = ["extension-module"]


编辑src/lib.rs文件
在lib.rs文件内写下如下内容

  1. #[macro_use] extern crate cpython;
  2. use cpython::{PyResult, Python};
  3. use std::mem;
  4. // 生成可以直接在python中import的so
  5. // 此处的module_name 和 initmodule_name可以随意命名
  6. // PyInit_edit_distence_rust1 这里的edit_distence_rust1要和使用时的so的文件名一致
  7. py_module_initializer!(module_name, initmodule_name, PyInit_edit_distence_rust, |py, m| {
  8. m.add(py, "__doc__", "Module documentation string")?;
  9. m.add(py, "edit_distance", py_fn!(py, edit_distance_py(a: &str, b: &str)))?;
  10. Ok(())
  11. });
  12. fn edit_distance_py(_: Python, a: &str, b: &str) -> PyResult<i32> {
  13. let result = edit_distance(a,b) as i32;
  14. Ok(result)
  15. }
  16. pub fn edit_distance(a: &str, b: &str) -> usize {
  17. let mut a = a;
  18. let mut b = b;
  19. let mut len_a = a.chars().count();
  20. let mut len_b = b.chars().count();
  21. if len_a < len_b{
  22. mem::swap(&mut a, &mut b);
  23. mem::swap(&mut len_a, &mut len_b);
  24. }
  25. // handle special case of 0 length
  26. if len_a == 0 {
  27. return len_b
  28. } else if len_b == 0 {
  29. return len_a
  30. }
  31. let len_b = len_b + 1;
  32. let mut pre;
  33. let mut tmp;
  34. let mut cur = vec![0; len_b];
  35. // initialize string b
  36. for i in 1..len_b {
  37. cur[i] = i;
  38. }
  39. // calculate edit distance
  40. for (i,ca) in a.chars().enumerate() {
  41. // get first column for this row
  42. pre = cur[0];
  43. cur[0] = i + 1;
  44. for (j, cb) in b.chars().enumerate() {
  45. tmp = cur[j + 1];
  46. cur[j + 1] = std::cmp::min(
  47. // deletion
  48. tmp + 1, std::cmp::min(
  49. // insertion
  50. cur[j] + 1,
  51. // match or substitution
  52. pre + if ca == cb { 0 } else { 1 }));
  53. pre = tmp;
  54. }
  55. }
  56. cur[len_b - 1]
  57. }
  58. // 生成单独的so文件,python中使用ctypes调用
  59. use std::os::raw::c_char;
  60. use std::ffi::CStr;
  61. #[no_mangle]
  62. pub extern "C" fn edit_distance_so(a: *const c_char, b: *const c_char) -> usize {
  63. let a = unsafe { CStr::from_ptr(a) };
  64. let mut a: &str =a.to_str().unwrap(); //
  65. let b = unsafe { CStr::from_ptr(b) };
  66. let mut b: &str = b.to_str().unwrap(); //
  67. let mut len_a = a.chars().count();
  68. let mut len_b = b.chars().count();
  69. if len_a < len_b{
  70. mem::swap(&mut a, &mut b);
  71. mem::swap(&mut len_a, &mut len_b);
  72. }
  73. // handle special case of 0 length
  74. if len_a == 0 {
  75. return len_b
  76. } else if len_b == 0 {
  77. return len_a
  78. }
  79. let len_b = len_b + 1;
  80. let mut pre;
  81. let mut tmp;
  82. let mut cur = vec![0; len_b];
  83. // initialize string b
  84. for i in 1..len_b {
  85. cur[i] = i;
  86. }
  87. // calculate edit distance
  88. for (i,ca) in a.chars().enumerate() {
  89. // get first column for this row
  90. pre = cur[0];
  91. cur[0] = i + 1;
  92. for (j, cb) in b.chars().enumerate() {
  93. tmp = cur[j + 1];
  94. cur[j + 1] = std::cmp::min(
  95. // deletion
  96. tmp + 1, std::cmp::min(
  97. // insertion
  98. cur[j] + 1,
  99. // match or substitution
  100. pre + if ca == cb { 0 } else { 1 }));
  101. pre = tmp;
  102. }
  103. }
  104. cur[len_b - 1]
  105. }


生成python包
cargo build —release
会看到如下输出
拷贝生成的so文件到项目目录下

  1. cp target/release/libedit_distence_rust.so edit_distence_rust.so
  2. ```bash
  3. src/lib.rs 里的PyInit_edit_distence_rust中的edit_distence_rust要和这里cp的目标文件名一致,不然会报如下错误(手动将edit_distence_rust.so改为hello.so,然后在python里执行import hello)
  4. 编辑test.py
  5. 这里使用python的编辑距离包Levenshtein进行结果和速度的对比
  6. ```python
  7. import Levenshtein
  8. import time
  9. import edit_distence_rust
  10. print(dir(edit_distence_rust))
  11. tic = time.time()
  12. for i in range(270000):
  13. dis = edit_distence_rust.edit_distance('我的中国心1', '别人也是调用的底层C文件吧')
  14. print('我的 rust cython so:', time.time()-tic, dis)
  15. import ctypes
  16. so = ctypes.CDLL('edit_distence_rust.so')
  17. tic = time.time()
  18. for i in range(270000):
  19. dis = so.edit_distance_so('我的中国心'.encode(
  20. 'utf8'), '别人也是调用的底层C文件吧'.encode('utf8'))
  21. print('我的 rust ctypes so:', time.time()-tic, dis)
  22. import Levenshtein
  23. tic = time.time()
  24. for i in range(270000):
  25. dis = Levenshtein.distance('我的中国心1', '别人也是调用的底层C文件吧')
  26. print('别人的库', time.time()-tic, dis)
  27. class Solution:
  28. def minDistance(self, word1, word2):
  29. l1 = len(word1) + 1
  30. l2 = len(word2) + 1
  31. if l2 > l1:
  32. return self.minDistance(word2, word1)
  33. m = [0]*l2 # 遍历到底i行时m[i]表示s1[:i-1]替换为s2[:j-1]的编辑距离
  34. for i in range(1, l2):
  35. m[i] = i
  36. p = 0 # 用于存储上一行左上角的值
  37. for i in range(1, l1):
  38. p = m[0]
  39. m[0] = i
  40. for j in range(1, l2):
  41. tmp = m[j] # 先将上一行i处的结果存起来
  42. m[j] = p if word1[i-1] == word2[j -
  43. 1] else min(m[j-1] + 1, m[j] + 1, p + 1)
  44. p = tmp
  45. return m[l2-1]
  46. s = Solution()
  47. d = 0
  48. tic = time.time()
  49. for i in range(270000):
  50. d = s.minDistance('我的中国心', '别人也是调用的底层C文件吧')
  51. print('自己的py实现', time.time()-tic, d)


相比纯python实现,可以取得43倍的加速,cython形式的so也比ctypes调用的快一些。
完结。

本作品采用知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议进行许可,转载请附上原文出处链接和本声明。
本文链接地址:

FlyAI-AI竞赛服务平台​www.flyai.com
63451c07fbfaa4a52c3ac3d7f698a326.png
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值