作者:黄天元,复旦大学博士在读,热爱数据科学与开源工具(R),致力于利用数据科学迅速积累行业经验优势和科学知识发现,涉猎内容包括但不限于信息计量、机器学习、数据可视化、应用统计建模、知识图谱等,著有《R语言高效数据处理指南》(《R语言数据高效处理指南》(黄天元)【摘要 书评 试读】- 京东图书)。知乎专栏:R语言数据挖掘。邮箱:huang.tian-yuan@qq.com.欢迎合作交流。
最近要做一个小任务,它的描述非常简单,就是识别向量的变化。比如一个整数序列是“4,4,4,5,5,6,6,5,5,5,4,4”,那么我们要根据数字的连续关系来分组,输出应该是“1,1,1,2,2,3,3,4,4,4,5,5”。这个函数用R写起来非常简单,稍加思考草拟如下:
get_id = function(x){
z = vector()
y = NULL
for (i in seq_along(x)) {
if(i == 1) y = 1
else if(x[i] != x[i-1]) y = y + 1
z = c(z,y)
}
z
}
不妨做个小测试:
get_id = function(x){
z = vector()
y = NULL
for (i in seq_along(x)) {
if(i == 1) y = 1
else if(x[i] != x[i-1]) y = y + 1
z = c(z,y)
}
z
}
c(rep(33L,3),rep(44L,4),rep(33L,3))
#> [1] 33 33 33 44 44 44 44 33 33 33
get_id(c(rep(33L,3),rep(44L,4),rep(33L,3)))
#> [1] 1 1 1 2 2 2 2 3 3 3
得到的结果绝对是准确的,而且按照这些代码,基本可以识别不同的数据类型,只要这些数据能够用“==”来判断是否相同(可能用setequal函数的健壮性更好)。
但是当数据量很大的时候,这样写是否足够快,就很重要了。这意味着看你要等一小时、一天还是一个月。想起自己小时候还学过C++,就希望尝试用Rcpp来加速,拟了代码如下:
library(Rcpp)
# 函数名称为get_id_c
cppFunction('
IntegerVector get_id_c(IntegerVector x){
int n = x.size();
IntegerVector out(n);
for (int i = 0; i < n; i++) {
if(i == 1) out[i] = 1;
else if(x[i] == x[i-1]) out[i] = out[i-1];
else out[i] = out[i-1] + 1;
}
return out;
}')
需要声明的是,C++需要定义数据类型,因为任务是正整数,所以函数就接受一个整数向量,输出一个整数向量。多年不用C++,写这么一段代码居然调试过程就出了3次错,惭愧。但是对性能的提升效果非常显著,我们先做一些简单尝试。先尝试1万个整数:
library(pacman)
p_load(tidyfst)
sys_time_print({
res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
![4b49a8e5ab3684acd9cd68c93e9727c6.png](https://i-blog.csdnimg.cn/blog_migrate/068be04a4315ebd1bcf19a00bdd5fe09.png)
0.14s和0.00s的区别,可能体会不够深。那么来10万个整数试试:
sys_time_print({
res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
![62d0417ab24e802a2ad59c78e53432f0.png](https://i-blog.csdnimg.cn/blog_migrate/08b4889177da80c9aa1dd86f4d8f6190.png)
13s vs 0s,有点不能忍了。那么,我们来100万个整数再来试试:
# 不要尝试这个
sys_time_print({
res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
# 可以尝试这个
sys_time_print({
res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
好的,关于这段代码:
sys_time_print({
res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
可以不要尝试了,因为直接卡死了。但是如果用Rcpp构造的函数,那么就放心试吧,我们还远远没有探知其算力上限。可以观察一下结果:
![cc4158924c4c7da1d62f784736f953fe.png](https://i-blog.csdnimg.cn/blog_migrate/e729ff98a92766f864140f102d44cf37.jpeg)
我们可以看到,1亿个整数,也就是0.69秒;10亿是7.15秒。虽然想尝试百亿,但是我的计算机内存已经不够了。
总结一下,永远不敢说R的速度不够快,只是自己代码写得烂而已(尽管完成了实现,其实get_id这个函数优化的空间是很多的),方法总比问题多。不多说了,去温习C++,学习Rcpp去了。后面如果有闲暇,来做一个Rcpp的学习系列。放一个核心资料链接:
Seamless R and C++ Integrationrcpp.org根据评论区提示,重新写了R的代码再来比较。代码如下:
library(pacman)
p_load(Rcpp,tidyfst)
get_id = function(x){
z = vector()
for (i in seq_along(x)) {
if(i == 1) z[i] = 1
else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
else z[i] = z[i-1]
}
z
}
cppFunction('
IntegerVector get_id_c(IntegerVector x){
int n = x.size();
IntegerVector out(n);
for (int i = 0; i < n; i++) {
if(i == 1) out[i] = 1;
else if(x[i] == x[i-1]) out[i] = out[i-1];
else out[i] = out[i-1] + 1;
}
return out;
}')
sys_time_print({
res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e9),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
1万:
![fdb6b92e3009d4421f13470079ff30c4.png](https://i-blog.csdnimg.cn/blog_migrate/0ce27ba020ed1100835c8ec5fdbdf9df.png)
10万:
![2a17f521467be565c78c0aecc97a0f2d.png](https://i-blog.csdnimg.cn/blog_migrate/78afe99bf1ba06cf2bd9e9ea084002f9.png)
100万:
![ea66334632cbb1970a52e807ec8c5946.png](https://i-blog.csdnimg.cn/blog_migrate/5eab83468d6d93be12db847cf93de42d.png)
1000万:
![a6abfe04f920e6ab41d85d59741383a0.png](https://i-blog.csdnimg.cn/blog_migrate/8233881222bb12af3841ee20fe4cf62e.png)
1亿:
![4d779e0a380b591c2d2dfb869959c94b.png](https://i-blog.csdnimg.cn/blog_migrate/1f61d0831b1ce59730bedc499e9cd5ba.png)
结论:还是Rcpp香。
三更:R代码提前设置向量长度,再比较。
library(pacman)
p_load(Rcpp,tidyfst)
get_id = function(x){
z = integer(length(x))
for (i in seq_along(x)) {
if(i == 1) z[i] = 1
else if(x[i] != x[i-1]) z[i] = z[i-1] + 1
else z[i] = z[i-1]
}
z
}
cppFunction('
IntegerVector get_id_c(IntegerVector x){
int n = x.size();
IntegerVector out(n);
for (int i = 0; i < n; i++) {
if(i == 1) out[i] = 1;
else if(x[i] == x[i-1]) out[i] = out[i-1];
else out[i] = out[i-1] + 1;
}
return out;
}')
sys_time_print({
res1 = get_id(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e4),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e5),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e6),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e7),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
sys_time_print({
res1 = get_id(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})
sys_time_print({
res2 = get_id_c(c(rep(33L,1e8),rep(44L,4),rep(33L,3)))
})
setequal(res1,res2)
![4894e519f56dd4213bdbecc41225dc2e.png](https://i-blog.csdnimg.cn/blog_migrate/534f0dde036c1420a1d71c5294db773f.jpeg)
![edb0594afa7fc44322474b5065c14432.png](https://i-blog.csdnimg.cn/blog_migrate/7da1363d123fd78910fc776411cbf3ce.png)
四更:对于这个任务来讲,data.table的rleid函数是最快的。R语言中的终极魔咒,能找到现成的千万不要自己写,总有巨佬在前头。不过直到10亿个整数才有难以忍受的差距。
![e755fb4390af7ff62568d4d2c10bb6ba.png](https://i-blog.csdnimg.cn/blog_migrate/69969649c805b5ccfc17b316d2cb2f10.png)
![6d5b340a5711ed4b4faa05f53290ebcc.png](https://i-blog.csdnimg.cn/blog_migrate/09c34c4fd18c02b3461426968578b945.png)