决策树的生成是一个递归过程。在决策树基本算法中,有三种情形会导致递归返回:
当前结点包含的样本全属于同一类别,无需划分
当前属性集为空,或是所有样本在所有属性上取值相同,无法划分
(3)当前结点包含的样本集合为空,不能划分。
在第(2)中情况下,我们把当前结点标记为叶结点,并将其类别设定为该结点所含样本最多的类别,即在利用当前结点的后验分布;
在第(3)种情况下,同样把当前结点标记为叶结点,但将其类别设定为其父结点所含样本最多的类别,即把父结点的样本分布作为当前结点的先验分布。
ID3(Iterative Dichotomiser 3)是一种经典的决策树学习算法,由 Ross Quinlan 在 1986 年提出。ID3 算法主要用于解决分类问题,它通过对数据集进行递归划分来构建决策树。
ID3 算法的基本思想是在每个节点上选择最佳的特征进行分割,以使得得到的子集尽可能地“纯净”。纯净度通常用信息增益(Information Gain)或基尼指数(Gini Index)等指标来衡量,这些指标可以反映数据集的纯度或不确定性程度。
ID3 算法的步骤如下:
- 若所有实例属于同一类,则将当前节点标记为叶节点,并以该类别作为节点的类别标签。
- 若特征集为空集,或者当前节点的所有实例属于同一类,则将当前节点标记为叶节点,并以当前节点中实例数最多的类别作为节点的类别标签。
- 否则,计算每个特征的信息增益(或基尼指数),选择信息增益(或基尼指数)最大的特征作为当前节点的划分特征。
- 根据选定的划分特征将数据集划分为多个子集,并为每个子集递归地应用上述步骤,构建子节点。
matlab代码如下:main.m
clc;clear;
data_name = 'xigua'; %数据名称,
data_r = 'csv'; %数据格式
dir_ = cd; %目录,默认同文件下
%% 数据预处理
filename = fullfile([dir_ '\' data_name '.' data_r]);%文件名
% 获取属性标签
data = readtable(filename,"VariableNamingRule","preserve");
size_data = size(data); %数据大小
if isempty(data.Properties.VariableDescriptions) %英文属性值,无描述
labels = data.Properties.VariableNames(1,1:size(data,2)-1); %获取属性值,必须是英文
else %使用原始列标题以支持中文属性值
labels = cell(1,size_data(2)-1);
for i = 1:size_data(2)-1
VariableDescriptions = data.Properties.VariableDescriptions;%获取原始名称
labels{i} = VariableDescriptions{i}(9:length(VariableDescriptions{1})-1);%添加标签
end
end
% 获取数据集
opts = detectImportOptions(filename);%检查数据
opts = setvartype(opts,opts.VariableNames,'char');
data = readtable(filename,opts) %读入数据
dataset = data{:,:}; %获取数据集
% 调用函数
myTree = ID3(dataset,labels);%生成决策树,并画出来
另创一个文件ID3.m, ID3代码:
function myTree = ID3(dataset,labels)
% 输入参数:
% dataset:数据集,元胞数组或字符串数组
% labels:属性标签,元胞数组或字符串数组
myTree = createTree(dataset,labels); %生成决策树
[nodeids,nodevalue,branchvalue] = print_tree(myTree); %解析决策树
tree_plot(nodeids,nodevalue,branchvalue); %画出
end
%% 使用熵最小策略构建决策树
function myTree = createTree(dataset,labels)
% 数据为空,则报错
if(isempty(dataset))
error('必须提供数据!')
end
size_data = size(dataset);
% 数据大小与属性数量不一致,则报错
if (size_data(2)-1)~=length(labels)
error('属性数量与数据集不一致!')
end
classList = dataset(:,size_data(2));
%全为同一类,熵为0,返回
if length(unique(classList))==1
myTree = char(classList(1));
return
end
%%属性集为空,应该用找最多数的那一类,这里取值NONE
if size_data(2) == 1
myTree = 'NONE';
%myTree = char(classList(1));
return
end
% 选取特征属性
bestFeature = chooseFeature(dataset);
bestFeatureLabel = char(labels(bestFeature));
% 构建树
myTree = containers.Map;
leaf = containers.Map;
% 该属性下的不同取值
featValues = dataset(:,bestFeature);
uniqueVals = unique(featValues);
% 删除该属性
labels=[labels(1:bestFeature-1) labels(bestFeature+1:length(labels))]; %删除该属性
% 对该属性下不同取值,递归调用ID3函数
for i=1:length(uniqueVals)
subLabels = labels(:)';
value = char(uniqueVals(i));
subdata = splitDataset(dataset,bestFeature,value);%数据集分割
leaf(value) = createTree(subdata,subLabels); %递归调用
myTree(char(bestFeatureLabel)) = leaf;
end
end
%% 计算信息熵
function shannonEnt = calShannonEnt(dataset)
data_size = size(dataset);
labels = dataset(:,data_size(2));
numEntries = data_size(1);
labelCounts = containers.Map;
for i = 1:length(labels)
label = char(labels(i));
if labelCounts.isKey(label)
labelCounts(label) = labelCounts(label)+1;
else
labelCounts(label) = 1;
end
end
shannonEnt = 0.0;
for key = labelCounts.keys
key = char(key);
labelCounts(key);
prob = labelCounts(key) / numEntries;
shannonEnt = shannonEnt - prob*(log(prob)/log(2));
end
end
% 选择熵最小的属性特征
function bestFeature=chooseFeature(dataset,~)
baseEntropy = calShannonEnt(dataset);
data_size = size(dataset);
numFeatures = data_size(2) - 1;
minEntropy = 2.0;
bestFeature = 0;
for i = 1:numFeatures
uniqueVals = unique(dataset(:,i));
newEntropy = 0.0;
for j=1:length(uniqueVals)
value = uniqueVals(j);
subDataset = splitDataset(dataset,i,value);
size_sub = size(subDataset);
prob = size_sub(1)/data_size(1);
newEntropy = newEntropy + prob*calShannonEnt(subDataset);
end
if newEntropy<minEntropy
minEntropy = newEntropy;
bestFeature = i;
end
end
end
% 分割数据集,取出该特征值为value的所有样本,并去除该属性
function subDataset = splitDataset(dataset,axis,value)
subDataset = {};
data_size = size(dataset);
for i=1:data_size(1)
data = dataset(i,:);
if string(data(axis)) == string(value)
subDataset = [subDataset;[data(1:axis-1) data(axis+1:length(data))]];
end
end
end
% 层序遍历决策树,返回nodeids,nodevalue,branchvalue
function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
nodeids(1) = 0;
nodeid = 0;
nodevalue={};
branchvalue={};
queue = {tree} ;%创建队列
while ~isempty(queue)
node = queue{1}; %取数据
queue(1) = []; %出队
if string(class(node))~="containers.Map" %叶节点
nodeid = nodeid+1;
nodevalue = [nodevalue,{node}];
elseif length(node.keys)==1 %节点
nodevalue = [nodevalue,node.keys];
node_info = node(char(node.keys));
nodeid = nodeid+1;
branchvalue = [branchvalue,node_info.keys];
for i=1:length(node_info.keys)
nodeids = [nodeids,nodeid];
end
end
if string(class(node))=="containers.Map"
keys = node.keys();
for i = 1:length(keys)
key = keys{i};
queue=[queue,{node(key)}]; %入队
end
end
nodeids_=nodeids;
nodevalue_=nodevalue;
branchvalue_ = branchvalue;
end
end
%% 参考treeplot,画图
function tree_plot(p,nodevalue,branchvalue)
[x,y,h] = treelayout(p); %x:横坐标,y:纵坐标;h:树的深度
f = find(p~=0); %非0节点
pp = p(f); %非0值
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];
X = X(:);
Y = Y(:);
n = length(p);
if n<500
hold on;
%plot(x,y,'ro',X,Y,'r-')
set(gcf,'Position',get(0,'ScreenSize'))
plot(X,Y,'r-');
nodesize = length(x);
for i=1:nodesize
t = text(x(i),y(i),nodevalue{1,i},'HorizontalAlignment','center');
t.EdgeColor = 'blue';
t.BackgroundColor = 'w';
end
for i=2:nodesize
j = 3*i-5;%获取连线坐标
t=text((X(j)+X(j+1))/2,(Y(j)+Y(j+1))/2,branchvalue{1,i-1},'HorizontalAlignment','center');
t.BackgroundColor = 'w';
end
hold off
else
plot(X,Y,'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end
数据集用excel表格,表格名称用xigua.csv
仔细按照流程去设置,不能运行,请来砍我!!!!!!!!!