分类树/装袋法/随机森林算法的R语言实现

本文介绍如何使用R语言实现分类树、装袋法(bagging)和随机森林(random forest)算法。作者先讲解分类树的基本知识,并逐一实现pred、gini、splitrule等函数;随后探讨装袋法与随机森林的基本原理及实现,最后进行性能测试。
摘要由CSDN通过智能技术生成

原文首发于简书于[2018.06.12]


本文是我自己动手用R语言编写的分类树实现代码,以及在此基础上实现的装袋法(bagging)和随机森林(random forest)算法。全文的结构如下:

  • 分类树
    • 基本知识
    • pred
    • gini
    • splitrule
    • splitrule_best
    • splitrule_random
    • splitting
    • buildTREE
    • predict
  • 装袋法与随机森林
    • 基本知识
    • bagging
    • predict_ensemble
    • 性能测试
  • 写在后面

全部的代码如下:

## x is the predictor matrix and y the response (a factor, column vector);
## both must already exist. The functions below read these globals.
n <- length(y)                # number of observations
k <- ncol(x)                  # number of predictors (columns of x)
ylevels <- levels(y)          # class labels, in factor-level order
nlevels <- length(ylevels)    # number of classes
pred <- function(suby) {
  # Majority vote: return the label from the global `ylevels` that occurs
  # most often in `suby`. Ties go to the earliest level (which.max keeps
  # the first maximum). Reads the globals `ylevels` and `nlevels`.
  tallies <- vapply(
    seq_len(nlevels),
    function(j) sum(suby == ylevels[j]),
    integer(1)
  )
  ylevels[which.max(tallies)]
}

gini <- function(suby, summary = FALSE) {
  # Gini impurity 1 - sum(p_j^2) of a class distribution.
  #
  # suby:    if summary = FALSE, the class labels themselves (a factor, or
  #          any vector table() accepts); if summary = TRUE, a numeric
  #          vector of per-class counts.
  # summary: whether `suby` already holds counts.
  # Returns a single number; an empty node has impurity 0 (not NaN).
  #
  # table() is used instead of summary(): summary() on a factor adds an
  # "NA's" bucket for missing labels, lumps classes beyond 100 into
  # "(Other)" (maxsum = 100), and does not count character vectors.
  counts <- if (summary) suby else as.vector(table(suby))
  obs <- sum(counts)
  if (isTRUE(obs == 0)) {
    return(0)
  }
  1 - drop(crossprod(counts)) / (obs * obs)
}

splitrule <- function(subx, suby) {
  # Exhaustive Gini search over one predictor (the original, unoptimised
  # version). subx holds the values of a single predictor; suby holds the
  # matching class labels.
  # Every midpoint between consecutive distinct x values is tried. Here
  # the "left" group is subx >= midpoint and the "right" group is
  # subx < midpoint.
  # Returns c(midpoint, weighted impurity) for the first best midpoint.
  # If subx has no variation, returns c(that x value, gini(suby)).
  xvalues <- sort(unique(subx))
  nx <- length(xvalues)
  if (nx <= 1) {
    # Nothing to split on: the node's impurity is its own Gini index.
    return(c(xvalues, gini(suby)))
  }
  midpoints <- (xvalues[-nx] + xvalues[-1]) / 2
  total <- length(suby)
  # A parent's impurity is the size-weighted sum of its children's Gini.
  scores <- vapply(midpoints, function(cp) {
    upper <- suby[subx >= cp]
    lower <- suby[subx < cp]
    (gini(upper) * length(upper) + gini(lower) * length(lower)) / total
  }, numeric(1))
  best <- which.min(scores)
  c(midpoints[best], scores[best])
}

splitrule_best <- function(subx, suby) {
  # Best Gini threshold for one predictor, found in a single vectorised
  # pass. This is an improved version of splitrule().
  # subx holds the values of a single predictor; suby holds the matching
  # class labels.
  #
  # The distinct x values are sorted. Candidate i sends x <= xvalues[i]
  # left and the rest right. The left class counts are the running
  # (cumulative) sums of the per-value class counts, so no candidate needs
  # its own rescan of the data.
  #
  # Returns c(splitpoint, weighted impurity) for the first minimum.
  # If subx has no variation, returns c(that x value, gini(suby)).
  xvalues <- sort(unique(subx))
  m <- length(xvalues)
  if (m <= 1) {
    return(c(xvalues, gini(suby)))
  }
  # Row i holds the class counts of observations with x == xvalues[i].
  # table() drops rows whose x or label is NA. It has one column per
  # factor level, independent of any global class count.
  counts <- table(factor(match(subx, xvalues), levels = seq_len(m)), suby)
  counts <- matrix(counts, nrow = m)
  totals <- colSums(counts)
  # The largest value can never be a threshold (its right side is empty).
  left <- apply(counts, 2, cumsum)[-m, , drop = FALSE]
  right <- t(totals - t(left))
  # obs * gini = obs - sum(count^2) / obs; an empty side contributes 0.
  weighted_gini <- function(cnt) {
    obs <- rowSums(cnt)
    ifelse(obs > 0, obs - rowSums(cnt^2) / obs, 0)
  }
  impurity <- (weighted_gini(left) + weighted_gini(right)) / sum(totals)
  best <- which.min(impurity)
  c(xvalues[best], impurity[best])
}

splitrule_random <- function(subx, suby) {
  # Extremely randomised split for one predictor. No sorting or
  # de-duplication is done: one observation is drawn at random from those
  # whose x is below the maximum, and its x value becomes the threshold.
  # x <= threshold goes left.
  # Returns c(threshold, weighted Gini impurity). If every x equals the
  # maximum, returns the sentinel c(0, 1) (impurity 1).
  candidates <- subx[subx != max(subx)]
  if (length(candidates) == 0) {
    return(c(0, 1))
  }
  threshold <- candidates[sample(length(candidates), 1)]
  left <- suby[subx <= threshold]
  right <- suby[subx > threshold]
  weighted <- gini(left) * length(left) + gini(right) * length(right)
  c(threshold, weighted / length(suby))
}

splitting <- function(subx, suby, split, rf) {
  # Choose the splitting predictor and threshold for one node.
  #
  # subx:  predictor matrix for the node's rows; suby: matching labels.
  # split: "best" (uses splitrule_best) or "random" (splitrule_random).
  # rf:    if TRUE, only round(sqrt(k)) randomly drawn predictors are
  #        candidates (random-forest style); otherwise all k predictors.
  # Reads the global k (number of predictors).
  # Returns c(variable index, threshold, impurity) of the candidate with
  # the lowest impurity.
  chosen_variable <- if (rf) sample.int(k, round(sqrt(k))) else seq_len(k)
  rule <- switch(split,
    best = splitrule_best,
    random = splitrule_random,
    stop("unknown split rule: ", split, call. = FALSE)
  )
  # drop = FALSE keeps a one-column matrix when a single predictor is
  # chosen (k = 1, or rf with k <= 2). Otherwise apply() fails on a vector.
  temp <- apply(subx[, chosen_variable, drop = FALSE], 2, rule, suby = suby)
  splitpoints <- temp[1, ]
  impurities <- temp[2, ]
  best <- which.min(impurities)
  c(chosen_variable[best], splitpoints[best], min(impurities))
}
buildTREE = function(x,y,split = "best",rf = FALSE) {
   
  TREE = NULL
  index = 1:n 
  tree = as.data.frame(list(leftnode=0,rightnode=0,splitvariable=0,splitpoint=0,obs=n,
                            pred=levels(y)[1],leaf=TRUE,begin=1,end=n)) 
  cur = 1 #当前正在分析的节点编号,它要追赶nodes,每次循环过后增加1。
  nodes = 1 #当前这棵树的总节点数量
  while ((cur<=nodes) ) {
    beginC = tree$begin[cur]; endC = tree$end[cur]
    indexCurrent = index[beginC:endC]
    subx = x[indexCurrent,]
    suby = y[indexCurr
  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值