决策树ID3算法-matlab实现

ID3_decision_tree.m

%% Driver: predict high/low sales with an ID3 decision tree
% Runs the whole pipeline: preprocess the built-in table into a 0/1 matrix,
% build the tree with id3(), then print and plot it.
% NOTE(review): print_tree and tree_plot are not defined in this listing;
% presumably they live elsewhere on the MATLAB path — confirm.
clear;

%% Step 1 — preprocessing (string table -> 0/1 matrix)
% attributes_label receives every header name (attributes + label);
% attributes receives the activity flag vector (all ones).
disp('正在进行数据预处理...');
[matrix, attributes_label, attributes] = id3_preprocess();

%% Step 2 — build the tree with the project's id3() function
disp('数据预处理完成,正在进行构造树...');
tree = id3(matrix, attributes_label, attributes);

%% Step 3 — print and draw the resulting tree
[nodeids, nodevalues] = print_tree(tree);
tree_plot(nodeids, nodevalues);

disp('ID3算法构建决策树完成!');

id3_preprocess.m

function [ matrix,attributes,activeAttributes ] = id3_preprocess( txt )
%% ID3 preprocessing: convert a string table into a 0/1 matrix
%
% Input (optional):
%   txt: cell array of strings laid out like the built-in table below —
%        row 1 is the header, column 1 is an index column that is dropped,
%        and the remaining columns are the attributes followed by the label.
%        When omitted or empty, the built-in sales dataset is used.
%
% Outputs:
%   matrix:           numeric 0/1 matrix, one row per data row of txt; every
%                     cell is encoded by trans2onezero ('坏', '否', '低' -> 0,
%                     anything else -> 1).
%   attributes:       header names of the attribute and label columns.
%   activeAttributes: row vector of ones, one entry per attribute
%                     (the label column is not counted).

%% Read the data
if nargin < 1 || isempty(txt)
    % Built-in sales dataset. The trailing semicolon keeps MATLAB from
    % echoing the whole table to the console on every call.
    txt = {  '序号'    '天气'    '是否周末'    '是否有促销'    '销量'
            ''        '坏'      '是'          '是'            '高'
            ''        '坏'      '是'          '是'            '高'
            ''        '坏'      '是'          '是'            '高'
            ''        '坏'      '否'          '是'            '高'
            ''        '坏'      '是'          '是'            '高'
            ''        '坏'      '否'          '是'            '高'
            ''        '坏'      '是'          '否'            '高'
            ''        '好'      '是'          '是'            '高'
            ''        '好'      '是'          '否'            '高'
            ''        '好'      '是'          '是'            '高'
            ''        '好'      '是'          '是'            '高'
            ''        '好'      '是'          '是'            '高'
            ''        '好'      '是'          '是'            '高'
            ''        '坏'      '是'          '是'            '低'
            ''        '好'      '否'          '是'            '高'
            ''        '好'      '否'          '是'            '高'
            ''        '好'      '否'          '是'            '高'
            ''        '好'      '否'          '是'            '高'
            ''        '好'      '否'          '否'            '高'
            ''        '坏'      '否'          '否'            '低'
            ''        '坏'      '否'          '是'            '低'
            ''        '坏'      '否'          '是'            '低'
            ''        '坏'      '否'          '是'            '低'
            ''        '坏'      '否'          '否'            '低'
            ''        '坏'      '是'          '否'            '低'
            ''        '好'      '否'          '是'            '低'
            ''        '好'      '否'          '是'            '低'
            ''        '坏'      '否'          '否'            '低'
            ''        '坏'      '否'          '否'            '低'
            ''        '好'      '否'          '否'            '低'
            ''        '坏'      '是'          '否'            '低'
            ''        '好'      '否'          '是'            '低'
            ''        '好'      '否'          '否'            '低'
            ''        '好'      '否'          '否'            '低'  };
end
attributes = txt(1, 2:end);
activeAttributes = ones(1, numel(attributes) - 1);   % label column excluded
data = txt(2:end, 2:end);                             % drop header row and index column

%% Encode each column
[rows, cols] = size(data);
matrix = zeros(rows, cols);
for j = 1:cols
    matrix(:, j) = cellfun(@trans2onezero, data(:, j));
end

end

function flag = trans2onezero(data)
% Encode one category string as a 0/1 double.
%   The "negative" values '坏', '否' and '低' map to 0; every other input
%   (including non-matching or non-char values) maps to 1.
negativeValues = {'坏', '否', '低'};
flag = double(~any(strcmp(data, negativeValues)));
end

id3.m

function [ tree ] = id3( examples, attributes, activeAttributes )
%% ID3 algorithm: build a binary ID3 decision tree
% Reference: https://github.com/gwheaton/ID3-Decision-Tree
%
% Inputs:
%   examples:         0/1 matrix, one row per example. Columns
%                     1..numel(activeAttributes) are attributes and column
%                     numel(activeAttributes)+1 is the label.
%   attributes:       cell array of names indexed like the attribute columns
%                     (it may also carry the label name after them).
%   activeAttributes: one entry per attribute; an entry > 0 means the
%                     attribute may still be used for a split, 0 (or a
%                     negative flag) means it is inactive.
%
% Output:
%   tree: struct with fields value/left/right. A leaf has value 'true' or
%         'false'; an internal node holds the name of its split attribute,
%         with the attribute==0 subtree in left and the attribute==1 subtree
%         in right. A field without a child holds the string 'null'.
%
% Raises an error when examples is empty.

if isempty(examples)
    error('必须提供数据!');
end

numberAttributes = length(activeAttributes);
numberExamples = size(examples, 1);
labels = examples(:, numberAttributes + 1);
isActive = activeAttributes > 0;

% New node; 'null' marks "no child".
tree = struct('value', 'null', 'left', 'null', 'right', 'null');

lastColumnSum = sum(labels);

% Every label is 1 -> 'true' leaf.
if lastColumnSum == numberExamples
    tree.value = 'true';
    return
end
% Every label is 0 -> 'false' leaf.
if lastColumnSum == 0
    tree.value = 'false';
    return
end

% No attribute left to split on -> majority label (ties go to 'true').
if ~any(isActive)
    tree.value = majority_label(lastColumnSum, numberExamples);
    return
end

%% Information gain of every still-active attribute
currentEntropy = binary_entropy(lastColumnSum, numberExamples);
gains = -1 * ones(1, numberAttributes);   % -1 keeps inactive attributes from winning
isTrue = labels ~= 0;
for i = 1:numberAttributes
    if isActive(i)
        isOne = examples(:, i) ~= 0;
        s1 = sum(isOne);
        s0 = numberExamples - s1;
        s1_and_true = sum(isOne & isTrue);
        s0_and_true = sum(~isOne & isTrue);
        gains(i) = currentEntropy ...
            - (s1 / numberExamples) * binary_entropy(s1_and_true, s1) ...
            - (s0 / numberExamples) * binary_entropy(s0_and_true, s0);
    end
end

%% Split on the attribute with the largest gain
[~, bestAttribute] = max(gains);
tree.value = attributes{bestAttribute};
activeAttributes(bestAttribute) = 0;   % not reused further down this path

examples_0 = examples(examples(:, bestAttribute) == 0, :);
examples_1 = examples(examples(:, bestAttribute) == 1, :);

% attribute == 0 -> left branch, attribute == 1 -> right branch. An empty
% branch becomes a leaf carrying this node's majority label.
tree.left = build_branch(examples_0, attributes, activeAttributes, lastColumnSum, numberExamples);
tree.right = build_branch(examples_1, attributes, activeAttributes, lastColumnSum, numberExamples);
end

function node = build_branch(subset, attributes, activeAttributes, parentTrue, parentTotal)
% Recurse into a non-empty subset, or return a majority-label leaf for an empty one.
if isempty(subset)
    node = struct('value', majority_label(parentTrue, parentTotal), ...
                  'left', 'null', 'right', 'null');
else
    node = id3(subset, attributes, activeAttributes);
end
end

function label = majority_label(numTrue, numTotal)
% 'true' when at least half of numTotal examples are labelled 1, else 'false'.
if numTrue >= numTotal / 2
    label = 'true';
else
    label = 'false';
end
end

function e = binary_entropy(numTrue, numTotal)
% Entropy (bits) of a 0/1 label set with numTrue ones out of numTotal;
% an empty set (numTotal == 0) has entropy 0.
if numTotal == 0
    e = 0;
    return
end
p1 = numTrue / numTotal;
p0 = (numTotal - numTrue) / numTotal;
e = plogp(p1) + plogp(p0);
end

function v = plogp(p)
% -p*log2(p) with the 0*log2(0) = 0 convention.
if p == 0
    v = 0;
else
    v = -1 * p * log2(p);
end
end

参考:
http://blog.csdn.net/lfdanding/article/details/50753239

  • 7
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用MATLAB实现ID3算法的示例代码:

```matlab
% Training set: one instance per row, last column is the target attribute
data = [1 0 1 0 1;
        1 0 1 1 1;
        1 1 1 0 0;
        0 0 1 0 1;
        0 0 0 0 0;
        0 1 1 0 0;
        0 1 0 1 0;
        0 1 0 1 1];
% Attribute names (one per attribute column, target excluded)
attribute_names = {'Outlook', 'Temperature', 'Humidity', 'Windy'};
% Target attribute name
target_attribute_name = 'PlayTennis';
% Build the decision tree
tree = id3(data, attribute_names, target_attribute_name);
% Test set (attribute columns only)
test_data = [1 0 1 0; 1 0 1 1; 0 1 0 1];
% Classify every test instance; classes are numeric, so print with %d
for i = 1:size(test_data, 1)
    classification = classify(tree, attribute_names, test_data(i, :));
    fprintf('Test data %d: %d\n', i, classification);
end
```

下面是ID3算法和分类函数的实现:

```matlab
function tree = id3(data, attribute_names, target_attribute_name)
% Build a threshold-based (binary split) ID3-style decision tree.
%   data:                  numeric matrix, one row per instance; the last
%                          column is the target attribute.
%   attribute_names:       cell array with one name per attribute column.
%   target_attribute_name: name of the target attribute (only passed along).
% Internal nodes have fields op (split attribute name), threshold and kids
% ({left, right}); leaves have op = '', kids = {} and class.
target_attribute = data(:, end);
target_attribute_values = unique(target_attribute);

% All instances share one target value -> single leaf
if numel(target_attribute_values) == 1
    tree = make_leaf(target_attribute_values(1));
    return;
end

% No attribute columns left -> leaf with the most frequent target value
if size(data, 2) == 1
    tree = make_leaf(mode(target_attribute));
    return;
end

[best_attribute_index, best_attribute_threshold] = choose_best_attribute(data);
subsets = split_data(data, best_attribute_index, best_attribute_threshold);

% A split that leaves one side empty makes no progress: recursing on the
% unchanged data would never terminate, so stop with a majority leaf.
if isempty(subsets{1}) || isempty(subsets{2})
    tree = make_leaf(mode(target_attribute));
    return;
end

tree.op = attribute_names{best_attribute_index};
tree.threshold = best_attribute_threshold;
tree.kids = {};
for i = 1:numel(subsets)
    tree.kids{i} = id3(subsets{i}, attribute_names, target_attribute_name);
end
end

function leaf = make_leaf(class_value)
% Leaf node; {{}} makes kids an empty cell instead of a 0x0 struct array.
leaf = struct('op', '', 'kids', {{}}, 'class', class_value);
end

function [best_attribute_index, best_attribute_threshold] = choose_best_attribute(data)
% Pick the attribute (and its threshold) with the largest information gain.
target_attribute = data(:, end);
num_attributes = size(data, 2) - 1;
information_gains = zeros(num_attributes, 1);
thresholds = zeros(num_attributes, 1);
for attribute_index = 1:num_attributes
    [thresholds(attribute_index), information_gains(attribute_index)] = ...
        choose_best_threshold(data(:, attribute_index), target_attribute);
end
[~, best_attribute_index] = max(information_gains);
best_attribute_threshold = thresholds(best_attribute_index);
% With fewer than two instances no threshold exists; fall back to the median
if isnan(best_attribute_threshold)
    best_attribute_threshold = median(data(:, best_attribute_index));
end
end

function [threshold, best_gain] = choose_best_threshold(attribute_values, target_attribute)
% Try the midpoint between each pair of consecutive sorted values and keep
% the one with the largest information gain (NaN / -Inf if none exists).
% The gain output is deliberately not named information_gain: an output
% variable with that name would shadow the information_gain() call below.
[sorted_values, order] = sort(attribute_values);
sorted_target = target_attribute(order);
threshold = NaN;
best_gain = -Inf;
for i = 1:numel(sorted_values) - 1
    candidate = (sorted_values(i) + sorted_values(i + 1)) / 2;
    gain = information_gain(sorted_target, sorted_values, candidate);
    if gain > best_gain
        threshold = candidate;
        best_gain = gain;
    end
end
end

function subsets = split_data(data, attribute_index, threshold)
% Split rows into {attribute <= threshold, attribute > threshold}.
attribute_values = data(:, attribute_index);
left_subset = data(attribute_values <= threshold, :);
right_subset = data(attribute_values > threshold, :);
subsets = {left_subset, right_subset};
end

function classification = classify(tree, attribute_names, instance)
% Walk from the root to a leaf and return the leaf's class.
while ~isempty(tree.kids)
    attribute_index = find(strcmp(attribute_names, tree.op));
    if instance(attribute_index) <= tree.threshold
        tree = tree.kids{1};
    else
        tree = tree.kids{2};
    end
end
classification = tree.class;
end

function e = entropy(target_attribute)
% Shannon entropy (bits) of a vector of discrete target values; 0 when empty.
if isempty(target_attribute)
    e = 0;
    return;
end
[~, ~, value_index] = unique(target_attribute);
p = accumarray(value_index(:), 1) / numel(target_attribute);
e = -sum(p .* log2(p));
end

function ig = information_gain(target_attribute, attribute_values, threshold)
% Information gain of splitting target_attribute by attribute_values <= threshold.
n = numel(target_attribute);
left_target_attribute = target_attribute(attribute_values <= threshold);
right_target_attribute = target_attribute(attribute_values > threshold);
p_left = numel(left_target_attribute) / n;
p_right = numel(right_target_attribute) / n;
ig = entropy(target_attribute) ...
    - p_left * entropy(left_target_attribute) ...
    - p_right * entropy(right_target_attribute);
end
```

这个实现假设输入数据是一个数值矩阵,其中每行表示一个实例,前面各列是属性,最后一列是目标属性。每个节点按"属性值 <= 阈值"对数据做二元划分;目标属性取值应为数值(示例中为二元的 0/1)。属性名称以字符串元胞数组传入(不包含目标属性),目标属性名称作为第三个参数单独传入。由于脚本需要直接调用 id3 和 classify,可将两段代码依次放入同一个脚本文件(R2016b 及以上版本支持在脚本末尾定义局部函数),或将各函数分别保存为同名的 .m 文件。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值