Julia: Handling Continuous Values for Decision Trees and Random Forests

This article walks through how continuous-valued data are handled for decision trees in Zhou Zhihua's Machine Learning, using watermelon density and sugar content as examples: sort the values, take the midpoints of adjacent values as candidate split points, and compute the information gain of each candidate to pick the best one. It shows how to compute the split points and gains for the 17 watermelon density values, arriving at 0.381 as the best split point for density.

Zhou Zhihua's Machine Learning (the "watermelon book") describes how decision trees handle continuous-valued attributes, using the two continuous watermelon attributes density and sugar content as examples.

Below, take the density values (features) of 17 watermelons as an example:

[0.697,0.774,0.634,0.608,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719]

How do we find the corresponding split points?

1. Sort the 17 continuous values X, then average each pair of adjacent values in the sorted sequence; this yields 16 candidate split points, split_X.

This is the bi-partition approach: from sort_X to split_X. Note that n continuous values do not necessarily give n-1 split points, because duplicate values may exist (see the sketch after the split_X list below).

First sort X to get sort_X:

[0.243, 0.245, 0.343, 0.36, 0.403, 0.437, 0.481, 0.556, 0.593, 0.608, 0.634, 0.639, 0.657, 0.666, 0.697, 0.719, 0.774]  # sort_X

Then take the midpoints of adjacent values to get split_X:

[0.244, 0.29400000000000004, 0.35150000000000003, 0.3815, 0.42000000000000004, 0.45899999999999996, 0.5185, 0.5745, 0.6005, 0.621, 0.6365000000000001, 0.648, 0.6615, 0.6815, 0.708, 0.7464999999999999]
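Because repeated values reduce the number of candidates, a version of this step that is robust to duplicates could look like the sketch below (not part of the original program; candidate_splits is just an illustrative name). The driver code in step 3 skips the unique call, since the 17 density values contain no repeats.

# Minimal sketch: candidate split points with duplicate handling.
function candidate_splits(X) :: Vector{Float64}
    sorted_unique = unique(sort(X))  # ascending and de-duplicated
    return (sorted_unique[1:end-1] .+ sorted_unique[2:end]) ./ 2  # midpoints of adjacent values
end

candidate_splits([0.25, 0.25, 0.5, 0.75])  # -> [0.375, 0.625]: 4 raw values, only 2 split points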

2. For each of the 16 candidate split points, compute the information gain of splitting there and collect the results into a gains vector; the split point with the largest gain is the best one.
In general, the larger the information gain, the greater the "purity improvement" obtained by splitting on that attribute or split point; other things being equal, higher purity means a better model.
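For reference (following the book's notation, not restated in the original post), the gain of bi-partitioning a data set D at threshold t is

Gain(D, t) = Ent(D) − (|D_t^-| / |D|) · Ent(D_t^-) − (|D_t^+| / |D|) · Ent(D_t^+),  with Ent(D) = −Σ_k p_k · log2(p_k)

where D_t^- and D_t^+ are the subsets of samples whose attribute value is ≤ t and > t, respectively, and p_k are the class proportions within a subset.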
gains:

[0.05682352941176472, 0.11847797178090891, 0.18663565267767523, 0.26293671403544794, 0.09399614386760902, 0.03069956878979646, 0.004082532221190538, 0.0027244389091765076, 0.0027244389091765076, 0.004082532221190482, 0.030699568789796516, 0.006543942807450409, 0.0012672425232922446, 0.02458344666805945, 0.0008309129573794705, 0.0674593804343554]

As the list shows, the largest gain is about 0.2629, and its corresponding split point is splits_X[4] = 0.3815 (quoted as 0.381 in the book). So the best split point for the continuous attribute "density" is this value: a sample's density is compared against the threshold to decide which branch it falls into.
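As a sanity check on that number (worked out here, not in the original post): the four samples with density ≤ 0.3815 are 0.243, 0.245, 0.343 and 0.360, all negative examples, while the remaining 13 samples contain 8 positive and 5 negative ones, so

Gain(D, 0.3815) = 0.998 − (4/17)·0 − (13/17)·Ent(8/13, 5/13) ≈ 0.998 − (13/17)·0.9612 ≈ 0.2629

which matches the fourth entry of the gains vector above.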

3. The program that computes the split points and the gains is as follows:

# Compute the information gain for every candidate split point.
# Repeated values in X are handled, since each sample falls on exactly one
# side of a given split point.
function get_gains(X,y,splits_X) :: Vector{Float64}
    @assert length(X) == length(y)
    gains = zeros(length(splits_X))  ## one gain per candidate split point
    for i in 1:length(splits_X)
        split_ = splits_X[i]
        greater_num = 0      ## number of samples with value > split_
        greater_yes_num = 0  ## ... of which are positive examples
        greater_no_num = 0   ## ... of which are negative examples
        less_num = 0         ## number of samples with value <= split_
        less_yes_num = 0     ## ... of which are positive examples
        less_no_num = 0      ## ... of which are negative examples
        for j in 1:length(X)
            if X[j] > split_
                greater_num += 1
                if y[j] == 1   ## the corresponding label is a positive example
                    greater_yes_num += 1
                else
                    greater_no_num += 1
                end
            else
                less_num += 1
                if y[j] == 1
                    less_yes_num += 1
                else
                    less_no_num += 1
                end
            end
        end
        # Entropy of the ">" branch; the special cases avoid 0 * log2(0) = NaN.
        if greater_num > 0
            if greater_yes_num == 0
                entrop_greater = -(greater_no_num/greater_num)*log2(greater_no_num/greater_num)
            elseif greater_no_num == 0
                entrop_greater = -(greater_yes_num/greater_num)*log2(greater_yes_num/greater_num)
            else
                entrop_greater = -(greater_yes_num/greater_num)*log2(greater_yes_num/greater_num) - (greater_no_num/greater_num)*log2(greater_no_num/greater_num)
            end
        else
            entrop_greater = 0.0
        end
        # Entropy of the "<=" branch, with the same special-case handling.
        if less_num > 0
            if less_yes_num == 0
                entrop_less = -(less_no_num/less_num)*log2(less_no_num/less_num)
            elseif less_no_num == 0
                entrop_less = -(less_yes_num/less_num)*log2(less_yes_num/less_num)
            else
                entrop_less = -(less_yes_num/less_num)*log2(less_yes_num/less_num) - (less_no_num/less_num)*log2(less_no_num/less_num)
            end
        else
            entrop_less = 0.0
        end
        # Gain = entropy of the whole data set minus the weighted branch entropies.
        # entrop_total is the global defined below; Julia looks it up when the
        # function is actually called, after the definition.
        gains[i] = entrop_total - greater_num/length(X)*entrop_greater - less_num/length(X)*entrop_less
    end
    return gains
end

X   = [0.697,0.774,0.634,0.608,0.556,0.403,0.481,0.437,0.666,0.243,0.245,0.343,0.639,0.657,0.360,0.593,0.719] # the density values -> features
y   = [true,true,true,true,true,true,true,true,false,false,false,false,false,false,false,false,false] # labels for the density values -> labels; true = positive example (good melon), false = negative example
entrop_total = 0.998  ## entropy of the whole data set (8 positive and 9 negative among the 17 melons): -(8/17*log2(8/17) + 9/17*log2(9/17)) ≈ 0.998; lower entropy means higher purity
sort_X = sort(X) # sort the continuous attribute values
splits_X = (sort_X[1:end-1] .+ sort_X[2:end])/2 # midpoints of adjacent sorted values -> candidate split points, Vector{Float64}
gains = get_gains(X,y,splits_X)
println("gains : $(gains)") # print the gains vector
println("gains max value : $(maximum(gains)) ")

Output:

gains : [0.05682352941176472, 0.11847797178090891, 0.18663565267767523, 0.26293671403544794, 0.09399614386760902, 0.03069956878979646, 0.004082532221190538, 0.0027244389091765076, 0.0027244389091765076, 0.004082532221190482, 0.030699568789796516, 0.006543942807450409, 0.0012672425232922446, 0.02458344666805945, 0.0008309129573794705, 0.0674593804343554]
gains max value : 0.26293671403544794  
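As a short follow-up (not part of the original post), the best split point itself, rather than just the maximum gain, can be recovered with argmax, and the samples can then be partitioned at that threshold. The snippet below reuses the X, y, splits_X and gains variables defined above:

best_i = argmax(gains)        # index of the largest gain -> 4
best_split = splits_X[best_i] # -> 0.3815
left  = y[X .<= best_split]   # labels of samples with density <= 0.3815
right = y[X .>  best_split]   # labels of samples with density >  0.3815
println("best split point : $(best_split)")
println("left  branch : $(count(left)) positive / $(count(.!left)) negative")   # 0 / 4
println("right branch : $(count(right)) positive / $(count(.!right)) negative") # 8 / 5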