Python决策树实现 适用分类变量、连续变量、缺失值

按照周志华老师《机器学习》(西瓜书)中讲的方法用权重处理缺失值,实现时发现不少问题。每当决策树选择包含缺失值的属性作为划分条件时,该属性缺失的实例会被同时分配到每个子节点(权重按各子节点的样本权重比例缩小),这无疑会增加运算量。同时,各项退出条件也不能再依据实例个数,而应该依据权重之和。总的来说,按照书上的方法实现一遍之后,很多以前不懂的地方一下子就明白了,对 scikit-learn 中的决策树也有了更好的理解。

CART 算法对分类变量也采用二叉划分(把类别取值分成两组)。这样不仅能提升运算速度,我感觉还能通过对类别取值进行组合,获得更好的泛化性能。

min_samples_split 的效果应该不如 min_samples_leaf,因为 min_samples_split 太粗糙。比如 min_samples_split=10 时,可能出现父节点有 12 个样本、划分后各分支都只有 3、4 个样本的情况,而 min_samples_leaf 会直接排除这样的划分。需要说明的是,scikit-learn 中的 min_samples_leaf 并不是先生成树再剪掉样本过少的叶节点,而是在搜索划分点时只考虑左右子节点样本数都不少于该值的划分,因此不会额外耗费太多运算。我这里只实现了类似 min_samples_split 的停止条件(按样本数或权重之和)。

没有实现根据泛化性能(例如验证集精度)进行剪枝的步骤。

 

 

下面是不能处理缺失值的版本:

import numpy as np
import pandas as pd
from pandas import DataFrame,Series

def Cal_entropy(frame,labelPosition):
    """Return the base-2 Shannon entropy of the label column (unweighted).

    Every row counts once. ``labelPosition`` is the name of the label column
    in ``frame``; NaN labels are not counted.
    """
    # relative frequency of each distinct label value
    probabilities = frame[labelPosition].value_counts(normalize=True)
    return -(probabilities * np.log2(probabilities)).sum()


def BestFeatureToSplit(frame,infoFunction,labelPosition,featureClass):
    """Find the split that leaves the least label impurity (unweighted rows).

    Parameters
    ----------
    frame : DataFrame holding the feature columns and the label column.
    infoFunction : callable ``f(frame, labelPosition=...) -> float`` that
        measures the impurity of a subset (e.g. Cal_entropy).
    labelPosition : name of the label column.
    featureClass : Series mapping feature name -> 0 for categorical features,
        any other value for continuous ones.

    Categorical features split into one branch per value. Continuous features
    are tried as two-way ``<=t`` / ``>t`` splits at every midpoint (rounded to
    4 decimals) between consecutive distinct values. Features with a single
    distinct value are skipped. Ties keep the earlier candidate.

    Returns ``(column, threshold)``. ``column`` is ``'bestColumnName'`` when no
    feature could be split. ``threshold`` is -1 unless some continuous split
    has won at some point; it is not reset when a later categorical feature
    wins, so it may then hold the cut point of an earlier continuous feature.
    """
    rowCount = frame.shape[0]

    def weightedImpurity(groupKey, groupSizes):
        # impurity of every group, weighted by that group's share of the rows
        perGroup = frame.groupby(groupKey).apply(infoFunction, labelPosition=labelPosition)
        return (perGroup * (groupSizes / rowCount)).sum()

    bestChaos, bestColumn, bestPoint = np.inf, 'bestColumnName', -1
    for column in featureClass.index:
        if featureClass[column] == 0:
            sizes = frame[column].value_counts()
            if len(sizes) == 1:
                continue
            chaos = weightedImpurity(column, sizes)
            if chaos < bestChaos:
                bestChaos, bestColumn = chaos, column
            continue

        distinct = frame[column].drop_duplicates().sort_values().to_numpy()
        if len(distinct) == 1:
            continue
        midpoints = (round((low + high) / 2, 4)
                     for low, high in zip(distinct[:-1], distinct[1:]))
        for point in midpoints:
            side = frame[column].map(lambda x, p=point: '>%s' % p if x > p else '<=%s' % p)
            chaos = weightedImpurity(side, side.value_counts())
            if chaos < bestChaos:
                bestChaos, bestColumn, bestPoint = chaos, column, point
    return bestColumn, bestPoint


def ValueUnderCategory(frame,featureClass):
    """Collect the distinct values of every categorical feature.

    Returns ``{column: distinct values}`` for each column whose flag in
    ``featureClass`` is 0; continuous columns are left out. Values come from
    ``Series.unique()`` in first-appearance order.
    """
    return {column: frame[column].unique()
            for column, flag in featureClass.items()
            if flag == 0}
    
    
def CreateTree(frame, labelPosition,featureClass, infoFunction, valueUnderCategory):
    """Recursively grow an unweighted decision tree as nested dicts.

    Returns either a label (leaf) or ``{feature: {branch_key: subtree}}``.
    Categorical branches are keyed by category value (taken from
    ``valueUnderCategory``), continuous ones by ``'<=t'`` and ``'>t'``.

    Stopping rules:
    * all labels identical -> that label;
    * at most 5 rows -> most frequent label;
    * no feature can split -> most frequent label.
    A category value that does not occur in ``frame`` becomes a leaf with the
    most frequent label of ``frame``. A categorical feature is removed from
    ``featureClass`` once used; continuous features stay available.
    """
    labels = frame.loc[:, labelPosition]
    if labels.nunique(dropna=False) == 1:
        return frame[labelPosition].iloc[0]
    majority = labels.value_counts().index[0]
    if len(frame) <= 5:
        return majority
    # Exit conditions are partly checked through the split search result, to
    # avoid repeating work BestFeatureToSplit already does.
    column, threshold = BestFeatureToSplit(frame, infoFunction, labelPosition, featureClass)
    if column == 'bestColumnName':
        return majority

    branches = {}
    if featureClass[column] == 0:
        remaining = featureClass.drop(index=column)
        for value in valueUnderCategory[column]:
            part = frame[frame.loc[:, column] == value]
            if len(part) == 0:
                branches[value] = majority
            else:
                branches[value] = CreateTree(part, labelPosition, remaining,
                                             infoFunction, valueUnderCategory)
    else:
        values = frame.loc[:, column]
        sides = (('<=%s' % threshold, values <= threshold),
                 ('>%s' % threshold, values > threshold))
        for key, mask in sides:
            branches[key] = CreateTree(frame[mask], labelPosition, featureClass,
                                       infoFunction, valueUnderCategory)
    return {column: branches}

下面是能处理缺失值的版本:

import numpy as np
import pandas as pd
from pandas import DataFrame,Series

def DataPrepare(frame):
    """Select the Titanic modelling columns and add the helper columns.

    Returns a new DataFrame containing Pclass, Age, SibSp, Parch, Fare,
    Embarked, Sex and Survived copied from ``frame``, plus:
    * ``W``: the initial sample weight, 1 for every row;
    * ``Cabin``: only the first letter of the cabin; missing or empty cabins
      become ``'Z'``.
    ``frame`` itself is not modified.
    """
    # explicit copy so adding columns never writes into a view of ``frame``
    frameCut = frame.loc[:, ['Pclass','Age','SibSp','Parch','Fare','Embarked',
                 'Sex','Survived']].copy()
    frameCut['W'] = 1
    # pd.isna recognises every missing marker (np.nan, float('nan'), None);
    # an identity test against np.nan does not.
    frameCut['Cabin'] = frame['Cabin'].map(
        lambda x: 'Z' if pd.isna(x) or x == '' else x[0])
    return frameCut
    
def Cal_entropy(frame,labelPosition):
    """Return the weight-aware base-2 Shannon entropy of the label column.

    Each row contributes its weight from column ``W`` instead of a count of
    one, so fractionally weighted rows count proportionally.
    """
    weights = frame.groupby(labelPosition)['W'].sum()
    p = weights / weights.sum()
    return -(p * np.log2(p)).sum()

def BestFeatureToSplit(frame,infoFunction,labelPosition,featureClass):
    """Choose the split with the largest missing-value-aware information gain.

    Follows the rule in Zhou Zhihua's *Machine Learning* (chapter 4, missing
    values). For feature a:
    * D~ is the set of rows where a is known;
    * rho is the share of the total weight ``W`` carried by D~;
    * ``Gain(D, a) = rho * (Ent(D~) - sum_v r~_v * Ent(D~_v))``, where r~_v is
      the weight share of branch v inside D~.

    Parameters
    ----------
    frame : DataFrame with the feature columns, the label column and the
        weight column ``W``.
    infoFunction : callable ``f(frame, labelPosition=...) -> float`` used as Ent
        (e.g. the weighted Cal_entropy).
    labelPosition : name of the label column.
    featureClass : Series mapping feature name -> 0 (categorical) or any other
        value (continuous).

    Returns
    -------
    (classColumnName, separatePoint, weightListFinal)
        classColumnName
            ``'bestColumnName'`` when no feature can split.
        separatePoint
            The threshold for a continuous winner, -1 otherwise.
        weightListFinal
            Maps every branch key to its weight share r~_v. Branch keys are the
            category value for categorical features, and ``'<t'`` / ``'>t'``
            for continuous ones; the ``'<t'`` branch holds the values ``<= t``.
            Rows whose split value is missing are down-weighted by this share
            when copied into each branch.
    """
    maxGain = -1
    separatePoint = -1
    classColumnName = 'bestColumnName'
    weightListFinal = Series([], dtype=float)
    totalWeight = frame['W'].sum()
    for column in featureClass.index:
        frameNotNull = frame[frame[column].notnull()]
        sumOfWeight = frameNotNull['W'].sum()
        nullRatio = sumOfWeight / totalWeight          # rho
        if featureClass[column] == 0:
            # a categorical split needs at least two observed values
            if frameNotNull[column].nunique() < 2:
                continue
            # entropy of the known-value rows D~, not of the whole frame
            baseEntropy = infoFunction(frameNotNull, labelPosition=labelPosition)
            grouped = frameNotNull.groupby(column)
            weightList = grouped['W'].sum() / sumOfWeight
            condEntropy = (grouped.apply(infoFunction, labelPosition=labelPosition)
                           * weightList).sum()
            gain = (baseEntropy - condEntropy) * nullRatio
            if gain > maxGain:
                maxGain = gain
                classColumnName = column
                separatePoint = -1
                weightListFinal = weightList
        else:
            values = frameNotNull[column].drop_duplicates().sort_values().to_numpy(dtype=float)
            if len(values) < 2:
                continue
            baseEntropy = infoFunction(frameNotNull, labelPosition=labelPosition)
            known = frameNotNull[column]
            # candidate cuts: midpoints of consecutive distinct values
            for i in np.round((values[:-1] + values[1:]) / 2, 4):
                classBy = Series(np.where(known > i, '>%s' % i, '<%s' % i),
                                 index=known.index)
                # rounding can push a cut outside the data; skip one-sided splits
                if classBy.nunique() < 2:
                    continue
                grouped = frameNotNull.groupby(classBy)
                weightList = grouped['W'].sum() / sumOfWeight
                condEntropy = (grouped.apply(infoFunction, labelPosition=labelPosition)
                               * weightList).sum()
                gain = (baseEntropy - condEntropy) * nullRatio
                if gain > maxGain:
                    maxGain = gain
                    separatePoint = i
                    classColumnName = column
                    weightListFinal = weightList
    return classColumnName,separatePoint,weightListFinal


def ValueUnderCategory(frame,featureClass):
    """Collect the distinct values of every categorical feature.

    Returns ``{column: distinct values}`` for each column whose flag in
    ``featureClass`` is 0; continuous columns are left out. Values come from
    ``Series.unique()`` in first-appearance order.

    ``unique()`` keeps missing values, so a column with NaN also lists NaN.
    An ``== NaN`` comparison never matches, which is why ``nan:`` leaves appear
    in the sample tree output.
    """
    return {column: frame[column].unique()
            for column, flag in featureClass.items()
            if flag == 0}
    
    
def CreateTree(frame, labelPosition,featureClass, infoFunction, valueUnderCategory):
    """Recursively grow a weighted decision tree that tolerates missing values.

    ``frame`` must carry a sample-weight column ``W``. The tree is returned as
    nested dicts ``{feature: {branch_key: subtree_or_label}}``; leaves are
    label values. Branch keys are:
    * the category value, for categorical features;
    * ``'<t'`` (values ``<= t``) and ``'>t'``, for continuous features.

    Rows whose split feature is missing are copied into every branch. Their
    weight is multiplied by that branch's weight share, as returned by
    BestFeatureToSplit.

    Stopping rules (all weight based):
    * one label carries more than 95% of the weight -> that label;
    * total weight below 8 -> weighted majority label;
    * no feature can split -> weighted majority label.
    A category value absent from ``frame`` also becomes a weighted-majority
    leaf. A categorical feature is dropped from ``featureClass`` once used;
    continuous features stay available.
    """
    weightLabelList = frame.groupby(labelPosition)['W'].sum()
    # weighted majority label; shared by every leaf created at this node
    majority = weightLabelList.idxmax()
    if weightLabelList.max() / weightLabelList.sum() > 0.95:
        return majority
    if frame['W'].sum() < 8:
        return majority
    classColumnName, separatePoint, weightListFinal = BestFeatureToSplit(
        frame, infoFunction, labelPosition, featureClass)
    if classColumnName == 'bestColumnName':
        return majority

    branches = {}
    mytree = {classColumnName: branches}
    column = frame[classColumnName]
    missing = frame[column.isnull()]

    def withMissing(known, share):
        # append the missing-value rows to ``known`` with weight scaled by ``share``
        if len(missing) == 0:
            return known
        return pd.concat([known, missing.assign(W=missing['W'] * share)])

    if featureClass[classColumnName] == 0:
        subfeatureClass = featureClass.drop(index=classColumnName)
        for value in valueUnderCategory[classColumnName]:
            known = frame[column == value]
            if len(known) == 0:
                branches[value] = majority
            else:
                branches[value] = CreateTree(withMissing(known, weightListFinal[value]),
                                             labelPosition, subfeatureClass,
                                             infoFunction, valueUnderCategory)
    else:
        underKey = '<%s' % separatePoint
        aboveKey = '>%s' % separatePoint
        # '<=' matches how BestFeatureToSplit labels rows ('<' unless x > t);
        # a strict '<' here would drop rows equal to the threshold.
        subframeUnder = withMissing(frame[column <= separatePoint], weightListFinal[underKey])
        subframeAbove = withMissing(frame[column > separatePoint], weightListFinal[aboveKey])
        branches[underKey] = CreateTree(subframeUnder, labelPosition, featureClass,
                                        infoFunction, valueUnderCategory)
        branches[aboveKey] = CreateTree(subframeAbove, labelPosition, featureClass,
                                        infoFunction, valueUnderCategory)
    return mytree

调用方式和运行结果:

# Load the Kaggle Titanic CSVs from a local folder (adjust ``path`` as needed).
path = 'E:/titanic/'
train, test = (pd.read_csv(path + name, index_col=[0])
               for name in ('train.csv', 'test.csv'))
trainCut = DataPrepare(train)
# Feature type flags: 0 = categorical, 1 = continuous.
featureClass = Series({'Pclass': 0, 'Age': 1, 'SibSp': 1, 'Parch': 1,
                       'Fare': 1, 'Embarked': 0, 'Sex': 0, 'Cabin': 0})
valueUnderCategory = ValueUnderCategory(trainCut, featureClass)
# Inspect the root split, then grow the whole tree (notebook output follows).
BestFeatureToSplit(trainCut, Cal_entropy, 'Survived', featureClass)
CreateTree(trainCut, 'Survived', featureClass, Cal_entropy, valueUnderCategory)
{'Sex': {'female': {'Pclass': {1: 1,
    2: {'Age': {'<56.0': {'Fare': {'<26.125': {'Fare': {'<12.825': 1,
          '>12.825': {'Age': {'<23.5': 1,
            '>23.5': {'Age': {'<27.5': 1,
              '>27.5': {'Age': {'<37.0': 1, '>37.0': 1}}}}}}}},
        '>26.125': 1}},
      '>56.0': 0}},
    3: {'Fare': {'<23.35': {'Embarked': {'C': {'Fare': {'<15.4938': {'Fare': {'<13.9354': 1,
            '>13.9354': 0}},
          '>15.4938': 1}},
        'Q': {'Parch': {'<0.5': {'Fare': {'<7.6812': 0,
            '>7.6812': {'Age': {'<20.0': {'Age': {'<15.5': 1, '>15.5': 1}},
              '>20.0': 1}}}},
          '>0.5': 0}},
        'S': {'Age': {'<36.5': {'Age': {'<32.0': {'Fare': {'<7.7625': 1,
              '>7.7625': {'Fare': {'<10.825': {'Fare': {'<10.1521': {'Parch': {'<0.5': {'Fare': {'<9.8396': {'Fare': {'<8.7666': {'Fare': {'<8.6729': {'Fare': {'<7.9875': {'Age': {'<18.5': 0,
                              '>18.5': 1}},
                            '>7.9875': 0}},
                          '>8.6729': 1}},
                        '>8.7666': 0}},
                      '>9.8396': 1}},
                    '>0.5': 1}},
                  '>10.1521': 0}},
                '>10.825': {'Fare': {'<17.25': {'SibSp': {'<0.5': 1,
                    '>0.5': 1}},
                  '>17.25': 0}}}}}},
            '>32.0': 1}},
          '>36.5': 0}},
        nan: 1}},
      '>23.35': {'Parch': {'<0.5': 1,
        '>0.5': {'Fare': {'<31.3312': 0, '>31.3312': 0}}}}}}}},
  'male': {'Cabin': {'A': {'Fare': {'<54.4646': {'Fare': {'<37.8125': 1,
        '>37.8125': 0}},
      '>54.4646': 1}},
    'B': {'Fare': {'<85.1396': {'Fare': {'<30.1': 0, '>30.1': 0}},
      '>85.1396': 1}},
    'C': {'Age': {'<17.5': 1,
      '>17.5': {'Fare': {'<98.2125': {'Age': {'<28.5': 1,
          '>28.5': {'Age': {'<47.5': {'Age': {'<37.5': 0, '>37.5': 0}},
            '>47.5': 1}}}},
        '>98.2125': 0}}}},
    'D': {'Fare': {'<77.0084': {'Fare': {'<58.2292': 0, '>58.2292': 1}},
      '>77.0084': 0}},
    'E': {'Age': {'<44.0': 1, '>44.0': 0}},
    'F': 0,
    'G': 0,
    'T': 0,
    'Z': {'Age': {'<9.5': {'SibSp': {'<2.5': {'Parch': {'<0.5': 0, '>0.5': 1}},
        '>2.5': {'Parch': {'<1.5': 0, '>1.5': 0}}}},
      '>9.5': {'Fare': {'<54.2479': {'Embarked': {'C': {'Fare': {'<27.1354': {'Fare': {'<15.1479': {'Age': {'<29.5': {'Fare': {'<14.1584': {'Fare': {'<9.9771': {'SibSp': {'<0.5': {'Fare': {'<5.6188': 0,
                        '>5.6188': {'Fare': {'<8.3042': {'Fare': {'<7.5625': {'Age': {'<22.75': 0,
                              '>22.75': 0}},
                            '>7.5625': 0}},
                          '>8.3042': 0}}}},
                      '>0.5': 0}},
                    '>9.9771': 1}},
                  '>14.1584': 0}},
                '>29.5': 0}},
              '>15.1479': 1}},
            '>27.1354': 0}},
          'Q': {'SibSp': {'<1.5': {'Fare': {'<7.7458': 0,
              '>7.7458': {'Age': {'<30.0': 0, '>30.0': 0}}}},
            '>1.5': 1}},
          'S': {'Parch': {'<0.5': {'Fare': {'<26.275': {'Fare': {'<13.25': {'Fare': {'<7.9104': {'Fare': {'<7.8646': {'Age': {'<32.5': {'Age': {'<23.5': 0,
                        '>23.5': {'Fare': {'<7.0104': 0,
                          '>7.0104': {'Age': {'<31.5': {'Age': {'<27.5': {'Age': {'<26.5': {'Fare': {'<7.0958': 0,
                                  '>7.0958': 0}},
                                '>26.5': 0}},
                              '>27.5': 0}},
                            '>31.5': 0}}}}}},
                      '>32.5': 0}},
                    '>7.8646': 0}},
                  '>7.9104': {'Fare': {'<7.9875': {'SibSp': {'<1.5': 0,
                      '>1.5': 0}},
                    '>7.9875': {'Age': {'<19.5': {'Fare': {'<8.1354': 0,
                        '>8.1354': {'Age': {'<18.5': 0, '>18.5': 0}}}},
                      '>19.5': {'Age': {'<26.5': 0,
                        '>26.5': {'Fare': {'<8.0812': 0,
                          '>8.0812': {'Age': {'<31.5': {'Pclass': {1: 0,
                              2: 0,
                              3: 0}},
                            '>31.5': {'Age': {'<40.5': 0,
                              '>40.5': 0}}}}}}}}}}}}}},
                '>13.25': 0}},
              '>26.275': 0}},
            '>0.5': 0}},
          nan: 0}},
        '>54.2479': {'Fare': {'<58.9375': 1,
          '>58.9375': {'Fare': {'<107.9104': 0, '>107.9104': 0}}}}}}}}}}}}

决策树的使用还是有问题:西瓜书决策树这一章讲了在生成树时如何处理缺失数据,但没有讲如何用这样生成的树去分类包含缺失值的实例。(C4.5 的常见做法是把这类实例按各分支的权重比例同时送入所有子节点,再按权重汇总各叶节点给出的类别分布。)

下面的预测函数不支持含缺失值的实例:

def Predict(ser,tree,featureClass):
    """Route one sample down a tree built by CreateTree and return the leaf label.

    Parameters
    ----------
    ser : Series with the sample's feature values, indexed by feature name.
    tree : nested dict ``{feature: {branch_key: subtree_or_label}}``, or a bare
        leaf label when the whole tree is a single leaf.
    featureClass : Series mapping feature name -> 0 (categorical) or any other
        value (continuous).

    Continuous branches are keyed ``'>t'`` plus ``'<t'`` or ``'<=t'``. A value
    equal to ``t`` goes to the lower branch, which is where training put it.
    Missing values are not handled specially; a NaN continuous value fails the
    ``> t`` test and goes to the lower branch.

    Raises KeyError when a categorical value has no branch in the tree.
    """
    node = tree
    while isinstance(node, dict):
        feature, branches = next(iter(node.items()))
        value = ser[feature]
        if featureClass[feature] == 0:
            node = branches[value]
            continue
        aboveKey = next(key for key in branches if key.startswith('>'))
        underKey = next(key for key in branches if key != aboveKey)
        # keys are '%s'-formatted floats, so float() recovers the exact threshold
        threshold = float(aboveKey[1:])
        node = branches[aboveKey] if value > threshold else branches[underKey]
    return node

def PredictMore(frame,tree,featureClass):
    """Predict every row of ``frame`` with Predict and return the labels.

    The result is a Series indexed like ``frame``. Its dtype is inferred from
    the predicted labels (e.g. int64 when every leaf is an int). Rows are
    walked positionally, so duplicate index labels are handled correctly.
    """
    predictions = [Predict(row, tree, featureClass) for _, row in frame.iterrows()]
    return Series(predictions, index=frame.index)

 

  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值