决策树生成算法与ID3算法java实现

算法过程

数据

在这里插入图片描述
最终需要分类的属性为“电脑”,它有2个不同值0和1,1有4个样本,0有2个样本。
为计算每个属性的信息增益,我们首先给定样本电脑分类所需的期望信息:

I(4,2)=-4/6log2(4/6)-2/6log2(2/6)=0.918

从“性别”属性开始。 “性别”=1,有3个“电脑”=1,2个“电脑”=0; “性别”=0,有1个“电脑”=1,没有“电脑”=0。

i= -3/5log2(3/5)-2/5log2(2/5)-1log2(1)=0.971

按“性别”划分后,样本的条件熵为

e=5/6(-3/5log2(3/5)-2/5log2(2/5))+1/6(-1log2(1))=0.809

信息增益是

Gain(性别)=I(4,2)-e=0.918-0.809=0.109

同理
Gain(学生)=0.459;
Gain(民族)=0.316;

决策树生成过程

在集合中找到信息增益最大的
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=0, student=0, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}
Gain(性别)=0.109
Gain(学生)=0.459;
Gain(民族)=0.316;

选择学生分类
学生(“1”)
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
{computer=1, gender=1, student=1, nation=0}
各个样本均相同则熵为0 分类结束

学生(“0”)
{computer=1, gender=0, student=0, nation=0}
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}

再次计算信息增益
Gain(性别)=0.9182958340544896
Gain(民族)=0.2516291673878229

选择性别分类
性别(“0”)
{computer=1, gender=0, student=0, nation=0}
熵为0 分类结束

性别(“1”)
{computer=0, gender=1, student=0, nation=0}
{computer=0, gender=1, student=0, nation=1}
熵为0 分类结束

决策树生成完毕

输出结果

[root<-(student:0)<-(gender:0)]:[{computer=1, gender=0, student=0, nation=0}]
[root<-(student:0)<-(gender:1)]:[{computer=0, gender=1, student=0, nation=0}, {computer=0, gender=1, student=0, nation=1}]
[root<-(student:1)]:[{computer=1, gender=1, student=1, nation=0}, {computer=1, gender=1, student=1, nation=0}, {computer=1, gender=1, student=1, nation=0}]

图示

在这里插入图片描述

代码

package decisiontree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Small ID3 decision-tree demo.
 *
 * <p>Constructing an instance builds a fixed six-sample data set, prints the
 * samples, the target property, the information gain of each attribute, and
 * finally every leaf of the tree together with the path that leads to it.
 */
public class ID3 {
	/** Gains below this value are treated as "no attribute separates this node". */
	private static final double MIN_GAIN=0.001;

	/** Name of the target property (the class label); set by createProperty. */
	public String nameProperty;
	/** Property names of the first sample, target property included. */
	public Set<String> nameSet;
	/** Distinct values the target property takes in the data set. */
	public Set<String> setProperty;
	/** Training samples, filled by createData. */
	ArrayList<SampleInter> data=new ArrayList<SampleInter> () ;

	public static void main(String[] args) {
		new ID3();
	}

	/** Fills the data set; constructor arguments are gender, student, nation, computer. */
	private void createData() {
		data.add(new Sample("1","1","0","1"));
		data.add(new Sample("0","0","0","1"));
		data.add(new Sample("1","1","0","1"));
		data.add(new Sample("1","1","0","1"));
		data.add(new Sample("1","0","0","0"));
		data.add(new Sample("1","0","1","0"));
	}

	/** Builds the data set, prints the per-attribute gains and then prints the tree. */
	public ID3() {
		createData();
		data.forEach(System.out::println);
		createProperty("computer");
		System.out.println(nameProperty+" "+setProperty);
		nameSet=data.get(0).getKeys();
		System.out.println("nameSet"+nameSet);

		System.out.println(getGain(data,"gender"));
		System.out.println(getGain(data,"student"));
		System.out.println(getGain(data,"nation"));

		// The target property is what we predict, so it is never a split candidate.
		Set<String> names =new HashSet<String>(nameSet);
		names.remove(nameProperty);
		tree(new ArrayList<SampleInter>(data),names,"root");
	}

	/**
	 * Recursively splits {@code data} on the attribute with the largest
	 * information gain and prints each leaf as {@code [path]:samples}.
	 *
	 * <p>A node becomes a leaf when it holds a single sample, when all of its
	 * samples are identical, or when no remaining attribute yields a gain of at
	 * least {@code MIN_GAIN}. Empty input prints nothing.
	 *
	 * @param data  samples that reached this node
	 * @param names attributes still available for splitting; not modified
	 * @param root  textual path from the root to this node
	 */
	public void tree(ArrayList<SampleInter> data,Set<String> names,String root) {
		if(data.isEmpty()) {
			return;
		}
		if(data.size()==1||allSamplesIdentical(data)) {
			printLeaf(root,data);
			return;
		}

		String maxName="";
		double maxGain=0;
		for(String name:names) {
			double gain=getGain(data,name);
			// Strict comparison: on ties the first attribute in iteration order wins.
			if(maxGain<gain) {
				maxGain=gain;
				maxName=name;
			}
		}
		// Only a (near-)zero gain ends the recursion. A gain close to 1 means the
		// attribute separates the classes perfectly and is exactly the case to split on.
		if(maxGain<MIN_GAIN) {
			printLeaf(root,data);
			return;
		}

		// Work on a copy so the caller's attribute set is left untouched.
		Set<String> remaining=new HashSet<String>(names);
		remaining.remove(maxName);
		for(String att:getSet(data,maxName)) {
			ArrayList<SampleInter> subset=new ArrayList<SampleInter>();
			for(SampleInter sample:data) {
				if(sample.getValue(maxName).equals(att)) {
					subset.add(sample);
				}
			}
			tree(subset,remaining,root+"<-("+maxName+":"+att+")");
		}
	}

	/** Returns true when every sample has the same string form as the first one. */
	private static boolean allSamplesIdentical(ArrayList<SampleInter> data) {
		String first=data.get(0).toString();
		return data.stream().allMatch(sample->first.equals(sample.toString()));
	}

	/** Prints one leaf in the form {@code [path]:samples}. */
	private static void printLeaf(String root,ArrayList<SampleInter> data) {
		System.out.println("["+root+"]:"+data);
	}

	/**
	 * Information gain of splitting {@code data} on attribute {@code name}:
	 * the entropy of the target property minus the weighted entropy of the
	 * target property within each value group of {@code name}.
	 *
	 * @param data samples to evaluate
	 * @param name attribute to split on
	 * @return the information gain, in bits
	 */
	public double getGain(ArrayList<SampleInter> data,String name) {
		double entropyBefore=getI(data,nameProperty);
		int size=data.size();
		double entropyAfter=0;
		for(String value:getSet(data,name)) {
			long count=getPropertyCount(data,name,value);
			if(count==0) {
				continue;
			}
			for(String label:setProperty) {
				long hits=getPropertyCount(data,name,value,nameProperty,label);
				// A label that does not occur contributes 0 and would give log2(0).
				if(hits==0) {
					continue;
				}
				double px=1.0*hits/count;
				entropyAfter-=px*log2(px)*count/size;
			}
		}
		return (entropyBefore-entropyAfter);
	}

	/**
	 * Entropy, in bits, of the distribution of property {@code name} over {@code data}.
	 *
	 * @param data samples to evaluate
	 * @param name property whose value distribution is measured
	 * @return the entropy; 0 for empty data
	 */
	public double getI(ArrayList<SampleInter> data ,String name) {
		int size=data.size();
		double entropy=0;
		for(String value:getSet(data,name)) {
			long count=getPropertyCount(data,name,value);
			double px=1.0*count/size;
			entropy-=px*log2(px);
		}
		return entropy;
	}

	/** Counts the samples whose property {@code name} equals {@code attribute}. */
	public long getPropertyCount(ArrayList<SampleInter> data,String name,String attribute) {
		return data.stream().filter(p->p.getValue(name).equals(attribute)).count();
	}

	/** Counts the samples that match both {@code name1=attribute1} and {@code name2=attribute2}. */
	public long getPropertyCount(ArrayList<SampleInter> data,String name1,String attribute1,String name2,String attribute2) {
		return data.stream().filter(p->p.getValue(name1).equals(attribute1)
				&&p.getValue(name2).equals(attribute2)).count();
	}

	/** Base-2 logarithm. */
	public static double log2(double x) {
		return Math.log(x)/Math.log(2);
	}

	/** Records {@code string} as the target property and collects its distinct values. */
	private void createProperty(String string) {
		nameProperty=string;
		setProperty=getSet(data,string);
	}

	/** Returns the distinct values that property {@code string} takes in {@code data}. */
	public Set<String> getSet(ArrayList<SampleInter> data,String string) {
		return data.stream().map( m->m.getValue(string)).collect(Collectors.toSet());
	}

}
/**
 * Read-only view of one training sample as a set of named string properties.
 */
interface SampleInter {
	/**
	 * Looks up a property of this sample.
	 *
	 * @param string property name
	 * @return the value held for that property
	 */
	String getValue(String string);

	/**
	 * Tells whether any property of this sample has the given value.
	 *
	 * @param value value to look for
	 * @return {@code true} if at least one property holds {@code value}
	 */
	boolean containsValue(String value);

	/**
	 * @return the names of the properties this sample carries
	 */
	Set<String> getKeys();
}
class Sample implements SampleInter{
	Map<String,String> map =new HashMap<String,String>();
	public Sample(String gender,String student,String nation,String computer ) {
		map.put("gender",gender);
		map.put("student",student);
		map.put("nation",nation);
		map.put("computer",computer);
	}
	public Set<String> getKeys() {
		return map.keySet();
	}
	public String getValue(String string) {
		return map.get(string);
	}
	public boolean containsValue(String value) {
		return map.containsValue(value);
	}
	public String toString() {
		return map.toString();
	}
}
  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值