宝具滑 / JS简单实现决策树(ID3算法)

<script> 
// 文章: https://www.jianshu.com/p/2b50a98cd75c
    /**
     * Binary decision tree trained with the ID3 information-gain criterion.
     *
     * If `config` is a plain (non-null, non-array) object the tree is trained
     * immediately via `training(config)`; otherwise call `training` later.
     *
     * @param {Object} [config] - training options; see `training`.
     */
    function DecisionTree(config) {
        // `typeof null === "object"`, so null must be excluded explicitly or
        // `new DecisionTree(null)` would crash inside `training`.
        if (config !== null && typeof config === "object" && !Array.isArray(config)) {
            this.training(config);
        }
    }
    DecisionTree.prototype = {
        // Replacing the whole prototype object drops the default `constructor`
        // link; restore it so that `instance.constructor === DecisionTree`.
        constructor: DecisionTree,

        // Split predicates, keyed by the operator name recorded on tree nodes.
        _predicates: {
            // Categorical (non-numeric) values: loose equality with the pivot.
            '==': function (a, b) { return a == b; },
            // Numeric values: "at least the pivot".
            '>=': function (a, b) { return a >= b; }
        },

        /**
         * Count how often each value of `attr` occurs in `items`.
         *
         * @param {Object[]} items - records to scan.
         * @param {string} attr - attribute whose values are counted.
         * @returns {Object<string, number>} prototype-less dictionary mapping each
         *   value (stringified, because it is used as an object key) to its count.
         */
        countUniqueValues(items, attr) {
            // No prototype, so values like "constructor" or "toString" are counted
            // instead of colliding with inherited Object.prototype members.
            const counter = Object.create(null);
            for (const item of items) {
                const value = item[attr];
                counter[value] = (counter[value] || 0) + 1;
            }
            return counter;
        },

        /**
         * Return the key with the largest value, e.g. `{a: 9, b: 2}` -> `"a"`.
         * On ties the key enumerated first wins; an empty counter yields undefined.
         *
         * @param {Object<string, number>} counter - key -> count dictionary.
         * @returns {string|undefined} the key with the highest count.
         */
        getMaxKey(counter) {
            let bestKey;
            for (const key in counter) {
                // Compare with `undefined`, not truthiness: "" is a legitimate key.
                if (bestKey === undefined || counter[key] > counter[bestKey]) {
                    bestKey = key;
                }
            }
            return bestKey;
        },

        /**
         * Most frequent value of `attr` in `items`, returned as a string key
         * (undefined when `items` is empty).
         */
        mostFrequentValue(items, attr) {
            return this.getMaxKey(this.countUniqueValues(items, attr));
        },

        /**
         * Partition `items` by whether `predicate(item[attr], pivot)` holds.
         *
         * @param {Object[]} items - records to partition.
         * @param {string} attr - attribute passed to the predicate.
         * @param {function(*, *): boolean} predicate - test applied to each record.
         * @param {*} pivot - value the attribute is compared against.
         * @returns {{match: Object[], notMatch: Object[]}} both partitions, each in
         *   input order.
         */
        split(items, attr, predicate, pivot) {
            const data = { match: [], notMatch: [] };
            for (const item of items) {
                (predicate(item[attr], pivot) ? data.match : data.notMatch).push(item);
            }
            return data;
        },

        /**
         * Shannon entropy (base 2) of the distribution of `attr` over `items`:
         * H(S) = -sum_i P(x_i) * log2(P(x_i)). Returns 0 for an empty set.
         */
        entropy(items, attr) {
            const counter = this.countUniqueValues(items, attr);
            let entropy = 0;
            for (const value in counter) {
                const p = counter[value] / items.length;
                entropy -= p * Math.log2(p);
            }
            return entropy;
        },

        /**
         * Find the binary split of `trainingSet` with the highest information gain.
         *
         * Every (own attribute, value) pair occurring in the data is tried as a
         * pivot: numbers with `>=`, everything else with `==`. The category
         * attribute and the ignored attributes are skipped.
         *
         * @returns {Object|null} `{ match, notMatch, attribute, predicate,
         *   predicateName, pivot, gain }` for the best split, or null when no split
         *   has a positive gain.
         */
        _findBestSplit(trainingSet, categoryAttr, ignoredAttributes, initialEntropy) {
            const ignored = new Set(ignoredAttributes);
            const alreadyChecked = new Set();
            let best = null;
            for (const item of trainingSet) {
                for (const attr of Object.keys(item)) {
                    if (attr == categoryAttr || ignored.has(attr)) continue;
                    const pivot = item[attr];
                    const predicateName = typeof pivot == 'number' ? '>=' : '==';
                    // Evaluate each (attribute, operator, pivot) candidate only once.
                    const candidateKey = attr + predicateName + pivot;
                    if (alreadyChecked.has(candidateKey)) continue;
                    alreadyChecked.add(candidateKey);
                    const predicate = this._predicates[predicateName];
                    const candidate = this.split(trainingSet, attr, predicate, pivot);
                    // A one-sided split carries no information, but H - H*n/n can
                    // round to a tiny positive "gain" and would create an empty leaf.
                    if (!candidate.match.length || !candidate.notMatch.length) continue;
                    const matchEntropy = this.entropy(candidate.match, categoryAttr);
                    const notMatchEntropy = this.entropy(candidate.notMatch, categoryAttr);
                    // IG(S, A) = H(S) - sum_t |t|/|S| * H(t), t in {match, notMatch}
                    const gain = initialEntropy - (matchEntropy * candidate.match.length
                        + notMatchEntropy * candidate.notMatch.length) / trainingSet.length;
                    // Strictly greater: on ties the first candidate found is kept.
                    if (gain > (best ? best.gain : 0)) {
                        best = {
                            match: candidate.match,
                            notMatch: candidate.notMatch,
                            attribute: attr,
                            predicate: predicate,
                            predicateName: predicateName,
                            pivot: pivot,
                            gain: gain
                        };
                    }
                }
            }
            return best;
        },

        /**
         * Recursively build a (sub)tree for `config.trainingSet`.
         *
         * A leaf `{ category }` (the most frequent category, as a string key) is
         * returned when the depth budget is 0, the set has at most `minItemsCount`
         * items, its entropy is at most `entropyThrehold`, or no split has a
         * positive gain.
         *
         * @param {Object} config - options as normalized by `training`:
         *   trainingSet, categoryAttr, ignoredAttributes, minItemsCount,
         *   entropyThrehold (sic), maxTreeDepth. The object is not modified.
         * @returns {Object} a leaf, or an internal node `{ attribute, predicate,
         *   predicateName, pivot, match, notMatch, matchedCount, notMatchedCount }`.
         */
        buildDecisionTree(config) {
            const trainingSet = config.trainingSet;
            const categoryAttr = config.categoryAttr;
            const maxTreeDepth = config.maxTreeDepth;
            const leaf = () => ({ category: this.mostFrequentValue(trainingSet, categoryAttr) });

            if (maxTreeDepth == 0 || trainingSet.length <= config.minItemsCount) return leaf();
            const initialEntropy = this.entropy(trainingSet, categoryAttr);
            if (initialEntropy <= config.entropyThrehold) return leaf();

            const best = this._findBestSplit(trainingSet, categoryAttr, config.ignoredAttributes, initialEntropy);
            if (!best) return leaf();

            // Give each child its own config copy. Mutating the shared `config`
            // let the match subtree's recursion shrink the depth budget seen by
            // the notMatch subtree, and leaked the changes to the caller.
            const childConfig = (subset) => Object.assign({}, config, {
                trainingSet: subset,
                maxTreeDepth: maxTreeDepth - 1
            });
            return {
                attribute: best.attribute,
                predicate: best.predicate,
                predicateName: best.predicateName,
                pivot: best.pivot,
                match: this.buildDecisionTree(childConfig(best.match)),
                notMatch: this.buildDecisionTree(childConfig(best.notMatch)),
                matchedCount: best.match.length,
                notMatchedCount: best.notMatch.length
            };
        },

        /**
         * Train the tree and store it in `this.root`.
         *
         * @param {Object} config
         * @param {Object[]} config.trainingSet - training records.
         * @param {string[]} [config.ignoredAttributes=[]] - attributes never used
         *   for splitting (e.g. names/identifiers).
         * @param {string} [config.categoryAttr='category'] - class-label attribute.
         * @param {number} [config.minItemsCount=1] - sets this small become leaves.
         * @param {number} [config.entropyThrehold=0.01] - (sic) sets with entropy
         *   at or below this become leaves.
         * @param {number} [config.maxTreeDepth=70] - maximum recursion depth.
         */
        training(config) {
            // Use the default only when an option is absent, so explicit falsy
            // settings such as `maxTreeDepth: 0` are honoured (`||` dropped them).
            const option = (value, fallback) => (value === undefined || value === null ? fallback : value);
            this.root = this.buildDecisionTree({
                trainingSet: config.trainingSet,
                ignoredAttributes: option(config.ignoredAttributes, []),
                categoryAttr: option(config.categoryAttr, 'category'),
                minItemsCount: option(config.minItemsCount, 1),
                entropyThrehold: option(config.entropyThrehold, 0.01),
                maxTreeDepth: option(config.maxTreeDepth, 70)
            });
        },

        /**
         * Classify `data` by walking the trained tree from `this.root`.
         *
         * @param {Object} data - record with the attributes used by the tree.
         * @returns {string|undefined} the category stored at the reached leaf.
         */
        predict(data) {
            let node = this.root;
            // Leaves are exactly the nodes carrying a `category` key; test for the
            // key rather than its truthiness so falsy categories such as "" work.
            while (!('category' in node)) {
                const predicate = node.predicate;
                node = predicate(data[node.attribute], node.pivot) ? node.match : node.notMatch;
            }
            return node.category;
        }
    };
</script>
<script>
    // Training data. "见面" is the class label ("见" = meet, "不见" = don't meet);
    // "姓名" (name) is only an identifier and is excluded from training below.
    var data = [
        { "姓名": "余夏", "年龄": 29, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
        { "姓名": "豆豆", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
        { "姓名": "帅常荣", "年龄": 26, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },
        { "姓名": "王涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
        { "姓名": "李东", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "见" },
        { "姓名": "王五五", "年龄": 23, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },
        { "姓名": "王小涛", "年龄": 22, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "见" },
        { "姓名": "李缤", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "高", 见面: "见" },
        { "姓名": "刘明", "年龄": 21, "长相": "帅", "体型": "胖", "收入": "低", 见面: "不见" },
        { "姓名": "红鹤", "年龄": 21, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },
        { "姓名": "李理", "年龄": 32, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },
        { "姓名": "周州", "年龄": 31, "长相": "帅", "体型": "瘦", "收入": "高", 见面: "不见" },
        { "姓名": "李乐", "年龄": 27, "长相": "不帅", "体型": "胖", "收入": "高", 见面: "不见" },
        { "姓名": "韩明", "年龄": 24, "长相": "不帅", "体型": "瘦", "收入": "高", 见面: "不见" },
        { "姓名": "小吕", "年龄": 28, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
        { "姓名": "李四", "年龄": 25, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
        { "姓名": "王鹏", "年龄": 30, "长相": "帅", "体型": "瘦", "收入": "低", 见面: "不见" },
    ];
    var decisionTree = new DecisionTree();

    // Manual checks of the helper methods: each one is run per attribute and
    // logged under a three-space-indented attribute label.
    console.log("函数 countUniqueValues 测试:");
    ["长相", "年龄", "收入"].forEach((attr) => console.log("   " + attr, decisionTree.countUniqueValues(data, attr)));
    console.log("函数 entropy 测试:");
    ["长相", "年龄", "收入"].forEach((attr) => console.log("   " + attr, decisionTree.entropy(data, attr)));
    console.log("函数 mostFrequentValue 测试:");
    ["年龄", "长相", "收入"].forEach((attr) => console.log("   " + attr, decisionTree.mostFrequentValue(data, attr)));
    console.log("函数 split 测试:");
    console.log("   长相", decisionTree.split(data, "长相", (a, b) => a == b, "不帅"));
    console.log("   年龄", decisionTree.split(data, "年龄", (a, b) => a >= b, 30));
    console.log("   年龄", decisionTree.split(data, "年龄", (a, b) => a < b, 25));

    // Train on the records above: predict "见面" and never split on the name.
    decisionTree.training({
        trainingSet: data,
        categoryAttr: '见面',
        ignoredAttributes: ['姓名']
    });

    // Classify a few unseen candidates (no "见面" field) and log each prediction.
    for (var comic of [
        { "姓名": "刘建1", "年龄": 21, "长相": "帅", "体型": "瘦", "收入": "高" },
        { "姓名": "刘建2", "年龄": 22, "长相": "不帅", "体型": "瘦", "收入": "高" },
        { "姓名": "刘建3", "年龄": 27, "长相": "帅", "体型": "瘦", "收入": "低" },
        { "姓名": "刘建4", "年龄": 30, "长相": "帅", "体型": "瘦", "收入": "高" },
        { "姓名": "刘建5", "年龄": 29, "长相": "帅", "体型": "胖", "收入": "高" },
        { "姓名": "刘建6", "年龄": 29, "长相": "帅", "体型": "胖", "收入": "低" },
        { "姓名": "刘建7", "年龄": 40, "长相": "帅", "体型": "瘦", "收入": "低" }
    ]) {
        console.log(comic, decisionTree.predict(comic));
    }
</script>

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值