21.8.23_决策树

reference:
1.https://blog.csdn.net/jiaoyangwm/article/details/79525237
2.https://www.bilibili.com/video/BV1CB4y1c7UQ?p=22

1. 决策树

原理: 将信息分类,使得沿着路径特征不断聚合

在这里插入图片描述
比如用这张图,根据深度由根到叶到子叶可以将分类目标分别设置为年龄,学历,经历,性别

到深的叶节点,每一个点上的特征有四种且每种只有一个

对于待预测样本,根据其每一个特征的值,选择对应的子表,逐一匹配,直到找到与之完全匹配的叶级子表,用该子表中样本的输出,通过平均(回归)或者投票(分类)为待预测样本提供输出。

一些说明:
随着子表的划分,信息熵(信息的混乱程度)越来越小,信息越来越纯,数据越来越有序。

方法:

import sklearn.tree as st

# 创建决策树回归器模型  决策树的最大深度为num
model = st.DecisionTreeRegressor(max_depth=num)
# 训练模型  
# train_x: 二维数组样本数据
# train_y: 训练集中对应每行样本的结果
model.fit(train_x, train_y)
# 测试模型
pred_test_y = model.predict(test_x)


import sklearn.datasets as sd
import sklearn.utils as su
import sklearn.tree as st
import sklearn.metrics as sm

#下载数据集
boston=sd.load_boston()
#打乱原数据集的排序
#random_state:随机种子(可认为是标签),相同的随机种子对应的排序一致
x,y=su.shuffle(boston.data,boston.target,random_state=7)
#分割,80%的用于训练
train_size=int(len(x)*0.8)
train_x,test_x,train_y,test_y=x[:train_size],x[train_size:],y[:train_size],y[train_size:]
#模型
num=len(boston.feature_names)
model=st.DecisionTreeRegressor(max_depth=num)
model.fit(train_x,train_y)
#训练
pred_y=model.predict(test_x)
print(sm.r2_score(test_y,pred_y))

2.手写版

能力有限,改了两天
基于 https://www.bilibili.com/video/BV13V411s7u2
http://cs229.stanford.edu/notes2021spring/notes2021spring/Decision_Trees_CS229.pdf

分类标准:
在这里插入图片描述
L(R1)以及L(R2)用这两个算
Pi指占整体的比例
在这里插入图片描述

import numpy as np
import matplotlib.pyplot as plt


np.random.seed(12)
num=50

x1=np.random.randint(2,5,size=(num,2))
x2=np.random.randint(4,6,size=(num,2))
#vstack上下叠加,hstack左右叠加
#x是点集,y是标签
x=np.vstack((x1,x2))
y=np.hstack((np.zeros(num),np.ones(num)))

plt.figure(figsize=(12,8))
plt.scatter(x[:,0],x[:,1],c=y,alpha=.4)
plt.show()
#Gini impurity:sum(Pi*(1-Pi)),值越小表示分割的效果越好

def giniscore(x,y):
    y_=[]
    for i in y:
        y_.append(i)
    #把y里面的所有标签都做成keys,value为0
    dict_new=dict.fromkeys(y_,0)
    #选出y是0,y是1for key in dict_new.keys():
       dict_new[key]=np.array([elem for idx,elem in enumerate(x) if (y[idx]==key).all()])
#giniscore
    num_sum=0
    for i in dict_new.values():
        #求比例
        #value就只有01
        tem=(i.shape[0]/x.shape[0])**2
        num_sum=tem+num_sum
    return 1-num_sum

def spliting(x,y,gini_base=1):
    dict_={}
    #print(x)
    #print('---'*50)
    for j in range(x.shape[1]):
        range_=sorted(set(x[:,j]))
        #j当作坐标,i为值
        for i in np.arange(range_[0],range_[-1],0.5):
            left_x = np.array([elem for idx, elem in enumerate(x) if x[idx, j] < i])
            right_x = np.array([elem for idx, elem in enumerate(x) if x[idx, j] >= i])
            left_y = np.array([y[idx] for idx, elem in enumerate(x) if x[idx, j] < i])
            right_y = np.array([y[idx] for idx, elem in enumerate(x) if x[idx, j] >= i])

            left_ratio, right_ratio = left_y.shape[0] / y.shape[0], right_y.shape[0] / y.shape[0]
            res = left_ratio * giniscore(left_x, left_y) + right_ratio * giniscore(right_x, right_y)
            dict_[(j, i)] = res

   #排序内容是每个项中的第二个元素,也就是item[1],比如【【12】,【34】,【56】】中的246
    res=dict(sorted(dict_.items(),key=lambda item: item[1]))
    keys=list(res.keys())
    primary_key=keys[0]

    gini_c=list(res.values())[0]
    gini_gain=gini_base-gini_c
    gini_base=gini_c

    if gini_gain>0:
        j,i=primary_key[0],primary_key[1]
        left_x = np.array([elem for idx, elem in enumerate(x) if x[idx, j] < i])
        right_x = np.array([elem for idx, elem in enumerate(x) if x[idx, j] >=i])
        left_y = np.array([y[idx] for idx, elem in enumerate(x) if x[idx, j] < i])
        right_y = np.array([y[idx] for idx, elem in enumerate(x) if x[idx, j] >= i])

        return primary_key,gini_base,left_x,left_y,right_x,right_y
    else:
        return None,None,None,None,None,None


def grow_tree(x,y,nodes=[]):
    primary_key, gini_base, left_x, left_y, right_x, right_y=spliting(x,y)
    nodes.append(primary_key)
    if len(left_x.flatten())>=2 & len(left_y)>=2:
        print('222')
        grow_tree(left_x,left_y,nodes)
    print(len(right_x.flatten()))
    print(len(right_y))
    if len(right_x.flatten()) >=2 & len(right_y) >=2:
        print('222')
        grow_tree(right_x,right_y,nodes)

node=[]
grow_tree(x,y,node)
print(node)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值