package com.decisiontree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
/**
 * Small ID3 decision-tree builder, demonstrated on a 14-row weather / "play" sample set.
 *
 * <p>Every sample is a {@code String[]} row. Attribute cells are located through an
 * attribute-name to column-index map, and the class label ("yes" or "no") is read from
 * column 4 of each row.
 *
 * <p>The original header comment on {@code main} carried the name "Spring_LGF".
 */
public class ID3 {

    /** Column of every sample row that holds the class label. */
    private static final int LABEL_COLUMN = 4;

    /** Label value counted as a positive ("play") sample. */
    private static final String LABEL_YES = "yes";

    /** Label value counted as a negative ("do not play") sample. */
    private static final String LABEL_NO = "no";

    /**
     * Builds the decision tree for the built-in weather samples and prints it to standard output.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // All values each attribute may take.
        ArrayList<String> attrOutlook = new ArrayList<String>();
        attrOutlook.add("sunny");
        attrOutlook.add("overcast");
        attrOutlook.add("rainy");
        ArrayList<String> attrTemperature = new ArrayList<String>();
        attrTemperature.add("hot");
        attrTemperature.add("mild");
        attrTemperature.add("cool");
        ArrayList<String> attrHumidity = new ArrayList<String>();
        attrHumidity.add("high");
        attrHumidity.add("normal");
        ArrayList<String> attrWindy = new ArrayList<String>();
        attrWindy.add("true");
        attrWindy.add("false");

        // Attribute name -> possible values.
        // NOTE(review): the key "trmperature" is misspelled, but it is used identically in both
        // maps and is printed at runtime, so it is kept verbatim here.
        HashMap<String, ArrayList<String>> attr = new HashMap<String, ArrayList<String>>();
        attr.put("outlook", attrOutlook);
        attr.put("trmperature", attrTemperature);
        attr.put("humidity", attrHumidity);
        attr.put("windy", attrWindy);

        // Attribute name -> column of that attribute inside a sample row.
        HashMap<String, Integer> attrIndex = new HashMap<String, Integer>();
        attrIndex.put("outlook", 0);
        attrIndex.put("trmperature", 1);
        attrIndex.put("humidity", 2);
        attrIndex.put("windy", 3);

        // Training samples: outlook, temperature, humidity, windy, label.
        String[][] data = {
            {"sunny", "hot", "high", "false", "no"},
            {"sunny", "hot", "high", "true", "no"},
            {"overcast", "hot", "high", "false", "yes"},
            {"rainy", "mild", "high", "false", "yes"},
            {"rainy", "cool", "normal", "false", "yes"},
            {"rainy", "cool", "normal", "true", "no"},
            {"overcast", "cool", "normal", "true", "yes"},
            {"sunny", "mild", "high", "false", "no"},
            {"sunny", "cool", "normal", "false", "yes"},
            {"rainy", "mild", "normal", "false", "yes"},
            {"sunny", "mild", "normal", "true", "yes"},
            {"overcast", "mild", "high", "true", "yes"},
            {"overcast", "hot", "normal", "false", "yes"},
            {"rainy", "mild", "high", "true", "no"}
        };

        ID3Tree root = new ID3Tree();
        buildID3Tree(root, data, attr, attrIndex);
        outputID3Tree(root);
    }

    /**
     * Builds an ID3 decision tree below {@code root}.
     *
     * <p>At every node the attribute with the smallest weighted entropy over its values is chosen
     * as the split. A value whose matching rows contain no "yes" or no "no" label becomes a leaf;
     * this includes a value with no matching rows at all, which (as in the original code) yields
     * a leaf predicting "play". Any other value is split further using the remaining attributes.
     * When no attributes remain for a non-empty, still-mixed subset, the node becomes a leaf
     * carrying the majority label (ties resolve to "yes").
     *
     * <p>Neither {@code attr} nor {@code data} is modified.
     *
     * @param root node to populate; its attribute name and children are filled in
     * @param data sample rows; the label is read from column 4. Rows are consumed up to the first
     *     row that is {@code null}, too short to hold a label, or has a {@code null} label, so
     *     trailing padding rows are tolerated
     * @param attr attribute name to list of possible values, for the attributes still available
     * @param attrIndex attribute name to column index inside a sample row
     * @return {@code root}, for convenience
     * @throws NullPointerException if {@code attrIndex} has no entry for an attribute in
     *     {@code attr}
     */
    public static ID3Tree buildID3Tree(ID3Tree root, String[][] data,
            HashMap<String, ArrayList<String>> attr, HashMap<String, Integer> attrIndex) {
        return grow(root, usableRows(data), attr, attrIndex);
    }

    /** Recursive worker behind {@link #buildID3Tree}; operates on an exact-size row list. */
    private static ID3Tree grow(ID3Tree node, ArrayList<String[]> rows,
            HashMap<String, ArrayList<String>> attr, HashMap<String, Integer> attrIndex) {
        if (attr.isEmpty()) {
            // Nothing left to split on: fall back to the majority label so the node can predict.
            if (!rows.isEmpty()) {
                int yes = 0;
                int no = 0;
                for (String[] row : rows) {
                    if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                        yes++;
                    } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                        no++;
                    }
                }
                node.isleaf = true;
                node.isPlay = yes >= no;
            }
            return node;
        }

        // Pick the attribute with the smallest weighted entropy (first one wins on ties).
        String bestAttr = null;
        double bestScore = 0.0;
        for (String name : attr.keySet()) {
            int column = columnOf(name, attrIndex);
            double weighted = 0.0;
            for (String value : attr.get(name)) {
                int[] counts = tally(rows, column, value);
                int labelled = counts[1] + counts[2];
                if (!rows.isEmpty()) {
                    weighted += entropy(counts[1], counts[2]) * (double) labelled / rows.size();
                }
            }
            if (bestAttr == null || weighted < bestScore) {
                bestAttr = name;
                bestScore = weighted;
            }
        }

        int column = columnOf(bestAttr, attrIndex);
        node.attrName = bestAttr;
        node.treeList = new ArrayList<ID3Tree>();

        // Children get their own copy of the attribute map so that neither the caller's map nor
        // sibling branches are affected by the removal of the attribute used here.
        HashMap<String, ArrayList<String>> remaining =
                new HashMap<String, ArrayList<String>>(attr);
        remaining.remove(bestAttr);

        for (String value : attr.get(bestAttr)) {
            ArrayList<String[]> subset = new ArrayList<String[]>();
            int yes = 0;
            int no = 0;
            for (String[] row : rows) {
                if (matches(row, column, value)) {
                    subset.add(row);
                    if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                        yes++;
                    } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                        no++;
                    }
                }
            }

            ID3Tree child = new ID3Tree();
            child.attrValue = value;
            if (yes == 0 || no == 0) {
                // Zero entropy for this value: the branch ends here.
                child.isleaf = true;
                child.isPlay = yes == subset.size();
            } else {
                child.isleaf = false;
                grow(child, subset, remaining, attrIndex);
            }
            node.treeList.add(child);
        }
        return node;
    }

    /** Returns the leading rows of {@code data} that carry a label, dropping trailing padding. */
    private static ArrayList<String[]> usableRows(String[][] data) {
        ArrayList<String[]> rows = new ArrayList<String[]>();
        for (String[] row : data) {
            if (row == null || row.length <= LABEL_COLUMN || row[LABEL_COLUMN] == null) {
                break;
            }
            rows.add(row);
        }
        return rows;
    }

    /** Tells whether the cell of {@code row} at {@code column} equals {@code value}. */
    private static boolean matches(String[] row, int column, String value) {
        String cell = row[column];
        return cell != null && cell.equals(value);
    }

    /** Counts, for one attribute value, {matching rows, "yes" rows, "no" rows}. */
    private static int[] tally(ArrayList<String[]> rows, int column, String value) {
        int[] counts = new int[3];
        for (String[] row : rows) {
            if (!matches(row, column, value)) {
                continue;
            }
            counts[0]++;
            if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                counts[1]++;
            } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                counts[2]++;
            }
        }
        return counts;
    }

    /** Looks up the row column of an attribute, failing with a descriptive message if absent. */
    private static int columnOf(String name, HashMap<String, Integer> attrIndex) {
        Integer column = attrIndex.get(name);
        if (column == null) {
            throw new NullPointerException("no column index registered for attribute: " + name);
        }
        return column;
    }

    /** Binary entropy (base 2) of a yes/no split; 0 when either count is 0. */
    private static double entropy(int yes, int no) {
        if (yes == 0 || no == 0) {
            // Avoids the 0 * log(0) = NaN case of the closed formula.
            return 0.0;
        }
        double total = yes + no;
        double pYes = yes / total;
        double pNo = no / total;
        return -pYes * log(pYes, 2.0) - pNo * log(pNo, 2.0);
    }

    /**
     * Prints the tree in pre-order, one node per line, as
     * {@code attrName attrValue isPlay isleaf} separated by single spaces.
     *
     * @param root node to start printing from
     */
    public static void outputID3Tree(ID3Tree root) {
        System.out.println(root.attrName + " " + root.attrValue + " " + root.isPlay + " " + root.isleaf);
        if (root.treeList != null) {
            for (ID3Tree child : root.treeList) {
                outputID3Tree(child);
            }
        }
    }

    /**
     * Computes a logarithm in an arbitrary base.
     *
     * @param value the number whose logarithm is taken
     * @param base the base of the logarithm
     * @return {@code Math.log(value) / Math.log(base)}
     */
    public static double log(double value, double base) {
        return Math.log(value) / Math.log(base);
    }

    /** One node of the decision tree. */
    static class ID3Tree {
        // Whether this node is a leaf.
        private boolean isleaf;
        // Whether to go out and play; only meaningful on leaf nodes.
        private boolean isPlay;
        // Value of the parent's split attribute that leads to this node.
        private String attrValue;
        // Child nodes, one per value of this node's split attribute (null until the node is split).
        private ArrayList<ID3Tree> treeList;
        // Name of the attribute this node splits on (null when the node is not split).
        private String attrName;

        public String getAttrName() {
            return attrName;
        }

        public void setAttrName(String attrName) {
            this.attrName = attrName;
        }

        public boolean isPlay() {
            return isPlay;
        }

        public void setPlay(boolean isPlay) {
            this.isPlay = isPlay;
        }

        public boolean isIsleaf() {
            return isleaf;
        }

        public void setIsleaf(boolean isleaf) {
            this.isleaf = isleaf;
        }

        public ArrayList<ID3Tree> getTreeList() {
            return treeList;
        }

        public void setTreeList(ArrayList<ID3Tree> treeList) {
            this.treeList = treeList;
        }

        public String getAttrValue() {
            return attrValue;
        }

        public void setAttrValue(String attrValue) {
            this.attrValue = attrValue;
        }
    }
}
package com.decisiontree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
/**
 * Small ID3 decision-tree builder, demonstrated on a 14-row weather / "play" sample set.
 *
 * <p>Every sample is a {@code String[]} row. Attribute cells are located through an
 * attribute-name to column-index map, and the class label ("yes" or "no") is read from
 * column 4 of each row.
 *
 * <p>The original header comment on {@code main} carried the name "Spring_LGF".
 */
public class ID3 {

    /** Column of every sample row that holds the class label. */
    private static final int LABEL_COLUMN = 4;

    /** Label value counted as a positive ("play") sample. */
    private static final String LABEL_YES = "yes";

    /** Label value counted as a negative ("do not play") sample. */
    private static final String LABEL_NO = "no";

    /**
     * Builds the decision tree for the built-in weather samples and prints it to standard output.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        // All values each attribute may take.
        ArrayList<String> attrOutlook = new ArrayList<String>();
        attrOutlook.add("sunny");
        attrOutlook.add("overcast");
        attrOutlook.add("rainy");
        ArrayList<String> attrTemperature = new ArrayList<String>();
        attrTemperature.add("hot");
        attrTemperature.add("mild");
        attrTemperature.add("cool");
        ArrayList<String> attrHumidity = new ArrayList<String>();
        attrHumidity.add("high");
        attrHumidity.add("normal");
        ArrayList<String> attrWindy = new ArrayList<String>();
        attrWindy.add("true");
        attrWindy.add("false");

        // Attribute name -> possible values.
        // NOTE(review): the key "trmperature" is misspelled, but it is used identically in both
        // maps and is printed at runtime, so it is kept verbatim here.
        HashMap<String, ArrayList<String>> attr = new HashMap<String, ArrayList<String>>();
        attr.put("outlook", attrOutlook);
        attr.put("trmperature", attrTemperature);
        attr.put("humidity", attrHumidity);
        attr.put("windy", attrWindy);

        // Attribute name -> column of that attribute inside a sample row.
        HashMap<String, Integer> attrIndex = new HashMap<String, Integer>();
        attrIndex.put("outlook", 0);
        attrIndex.put("trmperature", 1);
        attrIndex.put("humidity", 2);
        attrIndex.put("windy", 3);

        // Training samples: outlook, temperature, humidity, windy, label.
        String[][] data = {
            {"sunny", "hot", "high", "false", "no"},
            {"sunny", "hot", "high", "true", "no"},
            {"overcast", "hot", "high", "false", "yes"},
            {"rainy", "mild", "high", "false", "yes"},
            {"rainy", "cool", "normal", "false", "yes"},
            {"rainy", "cool", "normal", "true", "no"},
            {"overcast", "cool", "normal", "true", "yes"},
            {"sunny", "mild", "high", "false", "no"},
            {"sunny", "cool", "normal", "false", "yes"},
            {"rainy", "mild", "normal", "false", "yes"},
            {"sunny", "mild", "normal", "true", "yes"},
            {"overcast", "mild", "high", "true", "yes"},
            {"overcast", "hot", "normal", "false", "yes"},
            {"rainy", "mild", "high", "true", "no"}
        };

        ID3Tree root = new ID3Tree();
        buildID3Tree(root, data, attr, attrIndex);
        outputID3Tree(root);
    }

    /**
     * Builds an ID3 decision tree below {@code root}.
     *
     * <p>At every node the attribute with the smallest weighted entropy over its values is chosen
     * as the split. A value whose matching rows contain no "yes" or no "no" label becomes a leaf;
     * this includes a value with no matching rows at all, which (as in the original code) yields
     * a leaf predicting "play". Any other value is split further using the remaining attributes.
     * When no attributes remain for a non-empty, still-mixed subset, the node becomes a leaf
     * carrying the majority label (ties resolve to "yes").
     *
     * <p>Neither {@code attr} nor {@code data} is modified.
     *
     * @param root node to populate; its attribute name and children are filled in
     * @param data sample rows; the label is read from column 4. Rows are consumed up to the first
     *     row that is {@code null}, too short to hold a label, or has a {@code null} label, so
     *     trailing padding rows are tolerated
     * @param attr attribute name to list of possible values, for the attributes still available
     * @param attrIndex attribute name to column index inside a sample row
     * @return {@code root}, for convenience
     * @throws NullPointerException if {@code attrIndex} has no entry for an attribute in
     *     {@code attr}
     */
    public static ID3Tree buildID3Tree(ID3Tree root, String[][] data,
            HashMap<String, ArrayList<String>> attr, HashMap<String, Integer> attrIndex) {
        return grow(root, usableRows(data), attr, attrIndex);
    }

    /** Recursive worker behind {@link #buildID3Tree}; operates on an exact-size row list. */
    private static ID3Tree grow(ID3Tree node, ArrayList<String[]> rows,
            HashMap<String, ArrayList<String>> attr, HashMap<String, Integer> attrIndex) {
        if (attr.isEmpty()) {
            // Nothing left to split on: fall back to the majority label so the node can predict.
            if (!rows.isEmpty()) {
                int yes = 0;
                int no = 0;
                for (String[] row : rows) {
                    if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                        yes++;
                    } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                        no++;
                    }
                }
                node.isleaf = true;
                node.isPlay = yes >= no;
            }
            return node;
        }

        // Pick the attribute with the smallest weighted entropy (first one wins on ties).
        String bestAttr = null;
        double bestScore = 0.0;
        for (String name : attr.keySet()) {
            int column = columnOf(name, attrIndex);
            double weighted = 0.0;
            for (String value : attr.get(name)) {
                int[] counts = tally(rows, column, value);
                int labelled = counts[1] + counts[2];
                if (!rows.isEmpty()) {
                    weighted += entropy(counts[1], counts[2]) * (double) labelled / rows.size();
                }
            }
            if (bestAttr == null || weighted < bestScore) {
                bestAttr = name;
                bestScore = weighted;
            }
        }

        int column = columnOf(bestAttr, attrIndex);
        node.attrName = bestAttr;
        node.treeList = new ArrayList<ID3Tree>();

        // Children get their own copy of the attribute map so that neither the caller's map nor
        // sibling branches are affected by the removal of the attribute used here.
        HashMap<String, ArrayList<String>> remaining =
                new HashMap<String, ArrayList<String>>(attr);
        remaining.remove(bestAttr);

        for (String value : attr.get(bestAttr)) {
            ArrayList<String[]> subset = new ArrayList<String[]>();
            int yes = 0;
            int no = 0;
            for (String[] row : rows) {
                if (matches(row, column, value)) {
                    subset.add(row);
                    if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                        yes++;
                    } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                        no++;
                    }
                }
            }

            ID3Tree child = new ID3Tree();
            child.attrValue = value;
            if (yes == 0 || no == 0) {
                // Zero entropy for this value: the branch ends here.
                child.isleaf = true;
                child.isPlay = yes == subset.size();
            } else {
                child.isleaf = false;
                grow(child, subset, remaining, attrIndex);
            }
            node.treeList.add(child);
        }
        return node;
    }

    /** Returns the leading rows of {@code data} that carry a label, dropping trailing padding. */
    private static ArrayList<String[]> usableRows(String[][] data) {
        ArrayList<String[]> rows = new ArrayList<String[]>();
        for (String[] row : data) {
            if (row == null || row.length <= LABEL_COLUMN || row[LABEL_COLUMN] == null) {
                break;
            }
            rows.add(row);
        }
        return rows;
    }

    /** Tells whether the cell of {@code row} at {@code column} equals {@code value}. */
    private static boolean matches(String[] row, int column, String value) {
        String cell = row[column];
        return cell != null && cell.equals(value);
    }

    /** Counts, for one attribute value, {matching rows, "yes" rows, "no" rows}. */
    private static int[] tally(ArrayList<String[]> rows, int column, String value) {
        int[] counts = new int[3];
        for (String[] row : rows) {
            if (!matches(row, column, value)) {
                continue;
            }
            counts[0]++;
            if (LABEL_YES.equals(row[LABEL_COLUMN])) {
                counts[1]++;
            } else if (LABEL_NO.equals(row[LABEL_COLUMN])) {
                counts[2]++;
            }
        }
        return counts;
    }

    /** Looks up the row column of an attribute, failing with a descriptive message if absent. */
    private static int columnOf(String name, HashMap<String, Integer> attrIndex) {
        Integer column = attrIndex.get(name);
        if (column == null) {
            throw new NullPointerException("no column index registered for attribute: " + name);
        }
        return column;
    }

    /** Binary entropy (base 2) of a yes/no split; 0 when either count is 0. */
    private static double entropy(int yes, int no) {
        if (yes == 0 || no == 0) {
            // Avoids the 0 * log(0) = NaN case of the closed formula.
            return 0.0;
        }
        double total = yes + no;
        double pYes = yes / total;
        double pNo = no / total;
        return -pYes * log(pYes, 2.0) - pNo * log(pNo, 2.0);
    }

    /**
     * Prints the tree in pre-order, one node per line, as
     * {@code attrName attrValue isPlay isleaf} separated by single spaces.
     *
     * @param root node to start printing from
     */
    public static void outputID3Tree(ID3Tree root) {
        System.out.println(root.attrName + " " + root.attrValue + " " + root.isPlay + " " + root.isleaf);
        if (root.treeList != null) {
            for (ID3Tree child : root.treeList) {
                outputID3Tree(child);
            }
        }
    }

    /**
     * Computes a logarithm in an arbitrary base.
     *
     * @param value the number whose logarithm is taken
     * @param base the base of the logarithm
     * @return {@code Math.log(value) / Math.log(base)}
     */
    public static double log(double value, double base) {
        return Math.log(value) / Math.log(base);
    }

    /** One node of the decision tree. */
    static class ID3Tree {
        // Whether this node is a leaf.
        private boolean isleaf;
        // Whether to go out and play; only meaningful on leaf nodes.
        private boolean isPlay;
        // Value of the parent's split attribute that leads to this node.
        private String attrValue;
        // Child nodes, one per value of this node's split attribute (null until the node is split).
        private ArrayList<ID3Tree> treeList;
        // Name of the attribute this node splits on (null when the node is not split).
        private String attrName;

        public String getAttrName() {
            return attrName;
        }

        public void setAttrName(String attrName) {
            this.attrName = attrName;
        }

        public boolean isPlay() {
            return isPlay;
        }

        public void setPlay(boolean isPlay) {
            this.isPlay = isPlay;
        }

        public boolean isIsleaf() {
            return isleaf;
        }

        public void setIsleaf(boolean isleaf) {
            this.isleaf = isleaf;
        }

        public ArrayList<ID3Tree> getTreeList() {
            return treeList;
        }

        public void setTreeList(ArrayList<ID3Tree> treeList) {
            this.treeList = treeList;
        }

        public String getAttrValue() {
            return attrValue;
        }

        public void setAttrValue(String attrValue) {
            this.attrValue = attrValue;
        }
    }
}