c++语言get:_用C++给R语言加速:Rcpp简单用法

作者:黄天元,复旦大学博士在读,热爱数据科学与开源工具(R),致力于利用数据科学迅速积累行业经验优势和科学知识发现,涉猎内容包括但不限于信息计量、机器学习、数据可视化、应用统计建模、知识图谱等,著有《R语言高效数据处理指南》(《R语言数据高效处理指南》(黄天元)【摘要 书评 试读】- 京东图书)。知乎专栏:R语言数据挖掘。邮箱:huang.tian-yuan@qq.com.欢迎合作交流。

最近要做一个小任务,它的描述非常简单,就是识别向量的变化。比如一个整数序列是“4,4,4,5,5,6,6,5,5,5,4,4”,那么我们要根据数字的连续关系来分组,输出应该是“1,1,1,2,2,3,3,4,4,4,5,5”。这个函数用R写起来非常简单,稍加思考草拟如下:

get_id = function(x){
  z = vector()
  y = NULL
  for (i in seq_along(x)) {
    if(i == 1) y = 1
    else if(x[i] != x[i-1]) y = y + 1
    z = c(z,y)
  }
  z
}

不妨做个小测试:

get_id = function(x){
  z = vector()
  y = NULL
  for (i in seq_along(x)) {
    if(i == 1) y = 1
    else if(x[i] != x[i-1]) y = y + 1
    z = c(z,y)
  }
  z
}
c(rep(33L,3),rep(44L,4),rep(33L,3))
#>  [1] 33 33 33 44 44 44 44 33 33 33
get_id(c(rep(33L,3),rep(44L,4),rep(33L,3)))
#>  [1] 1 1 1 2 2 2 2 3 3 3

得到的结果绝对是准确的,而且按照这些代码,基本可以识别不同的数据类型,只要这些数据能够用“==”来判断是否相同(可能用setequal函数的健壮性更好)。

但是当数据量很大的时候,这样写是否足够快,就很重要了。这意味着看你要等一小时、一天还是一个月。想起自己小时候还学过C++,就希望尝试用Rcpp来加速,拟了代码如下:

library(Rcpp)

# 函数名称为get_id_c
cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

需要声明的是,C++需要定义数据类型,因为任务是正整数,所以函数就接受一个整数向量,输出一个整数向量。多年不用C++,写这么一段代码居然调试过程就出了3次错,惭愧。但是对性能的提升效果非常显著,我们先做一些简单尝试。先尝试1万个整数:

library(pacman)
p_load(tidyfst)  

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

4b49a8e5ab3684acd9cd68c93e9727c6.png

0.14s和0.00s的区别,可能体会不够深。那么来10万个整数试试:

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

62d0417ab24e802a2ad59c78e53432f0.png

13s vs 0s,有点不能忍了。那么,我们来100万个整数再来试试:

# 不要尝试这个
sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

# 可以尝试这个
sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

好的,关于这段代码:

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

可以不要尝试了,因为直接卡死了。但是如果用Rcpp构造的函数,那么就放心试吧,我们还远远没有探知其算力上限。可以观察一下结果:

cc4158924c4c7da1d62f784736f953fe.png

我们可以看到,1亿个整数,也就是0.69秒;10亿是7.15秒。虽然想尝试百亿,但是我的计算机内存已经不够了。

总结一下,永远不敢说R的速度不够快,只是自己代码写得烂而已(尽管完成了实现,其实get_id这个函数优化的空间是很多的),方法总比问题多。不多说了,去温习C++,学习Rcpp去了。后面如果有闲暇,来做一个Rcpp的学习系列。放一个核心资料链接:

Seamless R and C++ Integration​rcpp.org

根据评论区提示,重新写了R的代码再来比较。代码如下:

library(pacman)
p_load(Rcpp,tidyfst)

get_id = function(x){
  z = vector()
  for (i in seq_along(x)) {
    if(i == 1) z[i] = 1
    else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
    else z[i] = z[i-1]
  }
  z
}

cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

1万:

fdb6b92e3009d4421f13470079ff30c4.png

10万:

2a17f521467be565c78c0aecc97a0f2d.png

100万:

ea66334632cbb1970a52e807ec8c5946.png

1000万:

a6abfe04f920e6ab41d85d59741383a0.png

1亿:

4d779e0a380b591c2d2dfb869959c94b.png

结论:还是Rcpp香。


三更:R代码提前设置向量长度,再比较。

library(pacman)
p_load(Rcpp,tidyfst)

get_id = function(x){
  z = integer(length(x))
  for (i in seq_along(x)) {
    if(i == 1) z[i] = 1
    else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
    else z[i] = z[i-1]
  }
  z
}

cppFunction('
  IntegerVector get_id_c(IntegerVector x){
  int n = x.size();
  IntegerVector out(n);
  
  for (int i = 0; i < n; i++) {
    if(i == 1) out[i] = 1;
    else if(x[i] == x[i-1]) out[i] = out[i-1];
    else out[i] = out[i-1] + 1;
  }
  return out;
}')

sys_time_print({
  res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

sys_time_print({
  res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

sys_time_print({
  res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})

setequal(res1,res2)

4894e519f56dd4213bdbecc41225dc2e.png

edb0594afa7fc44322474b5065c14432.png

四更:对于这个任务来讲,data.table的rleid函数是最快的。R语言中的终极魔咒,能找到现成的千万不要自己写,总有巨佬在前头。不过直到10亿个整数才有难以忍受的差距。

e755fb4390af7ff62568d4d2c10bb6ba.png
1亿

6d5b340a5711ed4b4faa05f53290ebcc.png
10亿
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值