// 根据统计学习方法书本的步骤 — follows the steps described in the textbook "Statistical Learning Methods" (统计学习方法).
import java.util.List;
/**
 * Computes the empirical entropy H(D) of the class label of a sample set.
 */
public class Entroy {
    /**
     * Computes H(D) = -sum_k p_k * log2(p_k), where p_k is the fraction of samples whose label
     * equals {@code category[k]}. Labels that never occur contribute 0 (the 0*log 0 = 0
     * convention) instead of producing NaN.
     *
     * @param list     the samples; the last element of each row is its class label
     * @param category the possible class labels
     * @return the entropy in bits; 0 for a null or empty sample list
     */
    public double HD(List<Object[]> list, Object[] category) {
        if (list == null || list.isEmpty()) {
            return 0;
        }
        int labelIndex = list.get(0).length - 1;
        double total = list.size();
        double hd = 0;
        for (Object label : category) {
            int count = 0;
            for (Object[] row : list) {
                if (same(row[labelIndex], label)) {
                    count++;
                }
            }
            if (count == 0) {
                // 0 * log(0) is defined as 0; computing it directly would give NaN.
                continue;
            }
            double p = count / total;
            hd -= p * (Math.log(p) / Math.log(2.0));
        }
        return hd;
    }

    /** Null-safe value equality; identical references short-circuit. */
    private static boolean same(Object a, Object b) {
        return a == b || (a != null && a.equals(b));
    }
}
import java.util.List;
/**
 * Computes the conditional entropy H(D|A) of the class label given one feature column.
 */
public class Conditition {
    /**
     * Computes H(D|A) = sum_k |D_k|/|D| * H(D_k), where D_k are the samples whose column
     * {@code n} equals the k-th value in {@code array.get(n)}. Feature values that do not occur
     * have zero weight and are skipped.
     *
     * @param list  the samples; the last element of each row is its class label
     * @param array candidate values for each column; the last entry lists the class labels
     * @param n     index of the feature column A
     * @return the conditional entropy in bits; 0 for a null or empty sample list
     */
    public double GD(List<Object[]> list, List<Object[]> array, int n) {
        if (list == null || list.isEmpty()) {
            return 0;
        }
        Object[] values = array.get(n);
        Object[] category = array.get(array.size() - 1);
        double total = list.size();
        double gd = 0;
        for (Object value : values) {
            int count = 0;
            for (Object[] row : list) {
                if (same(row[n], value)) {
                    count++;
                }
            }
            if (count == 0) {
                // Weight |D_k|/|D| is zero; skipping avoids any 0 * NaN.
                continue;
            }
            gd += (count / total) * HH(list, value, n, category);
        }
        return gd;
    }

    /**
     * Computes the label entropy H(D_k) of the subset of {@code list} whose column {@code n}
     * equals {@code object}.
     *
     * @param list     the samples; the last element of each row is its class label
     * @param object   the feature value selecting the subset
     * @param n        index of the feature column
     * @param category the possible class labels
     * @return the subset's entropy in bits; 0 when the subset (or {@code list}) is empty
     */
    public double HH(List<Object[]> list, Object object, int n, Object[] category) {
        if (list == null || list.isEmpty()) {
            return 0;
        }
        int labelIndex = list.get(0).length - 1;
        double[] counts = new double[category.length];
        double subsetSize = 0;
        for (Object[] row : list) {
            if (!same(row[n], object)) {
                continue;
            }
            subsetSize++;
            for (int k = 0; k < category.length; k++) {
                if (same(row[labelIndex], category[k])) {
                    counts[k]++;
                }
            }
        }
        if (subsetSize == 0) {
            // Empty subset: previously 0/0 = NaN.
            return 0;
        }
        double hd = 0;
        for (double count : counts) {
            if (count == 0) {
                // 0 * log(0) contributes 0. The running sum must NOT be reset here.
                continue;
            }
            double p = count / subsetSize;
            hd -= p * (Math.log(p) / Math.log(2.0));
        }
        return hd;
    }

    /** Null-safe value equality; identical references short-circuit. */
    private static boolean same(Object a, Object b) {
        return a == b || (a != null && a.equals(b));
    }
}
/**
 * Selects the feature with the largest information gain, prints its name, and returns its index.
 */
public class OutPut {
    /**
     * Returns the index of the largest finite entry of {@code GD} and prints the matching name
     * from {@code feature} to standard output. Ties keep the earliest index. NaN entries never
     * win; if every entry is NaN, or {@code GD} is empty, index 0 is used.
     *
     * @param GD      information gain per feature
     * @param feature feature names, indexed like {@code GD}
     * @return the index of the selected feature
     */
    public int output(double[] GD, String[] feature) {
        int best = 0;
        for (int i = 1; i < GD.length; i++) {
            boolean bestIsNaN = Double.isNaN(GD[best]);
            // A NaN at the current best position would otherwise block every later candidate,
            // because every comparison against NaN is false.
            if (GD[i] > GD[best] || (bestIsNaN && !Double.isNaN(GD[i]))) {
                best = i;
            }
        }
        System.out.println(feature[best]);
        return best;
    }
}
import java.util.ArrayList;
import java.util.List;
import static java.lang.Float.NaN;
//ID3算法
import java.util.ArrayList;
import java.util.List;
import static java.lang.Float.NaN;
public class ID3 {
    /** Information gains at or below this threshold count as "no useful split". */
    private static final double MIN_GAIN = 1e-12;

    /**
     * Recursively grows an ID3 decision tree over {@code list}, depth first. At each node the
     * feature with the largest information gain g(D, A) = H(D) - H(D|A) is chosen via
     * {@code OutPut.output}, which also prints its name. Each non-empty, impure subset produced
     * by that feature is then split further.
     *
     * @param list    samples at the current node; the last element of each row is the class label
     * @param feature display names of the columns, indexed like the rows
     * @param array   possible values of each column; the last entry lists the class labels
     */
    public void id3(List<Object[]> list, String[] feature, List<Object[]> array) {
        if (list == null || list.isEmpty()) {
            return;
        }
        int labelIndex = list.get(0).length - 1;
        Entroy entropy = new Entroy();
        Conditition conditional = new Conditition();
        double hd = entropy.HD(list, array.get(labelIndex));

        double[] gains = new double[array.size() - 1];
        boolean canSplit = false;
        for (int i = 0; i < gains.length; i++) {
            double gain = hd - conditional.GD(list, array, i);
            // Defensive: a NaN gain (e.g. from a 0/0 in the entropy helpers) counts as "no gain",
            // so it can neither win nor poison the arg-max.
            gains[i] = Double.isNaN(gain) ? 0 : gain;
            if (gains[i] > MIN_GAIN) {
                canSplit = true;
            }
        }
        // No feature separates these samples any further (e.g. identical features with
        // conflicting labels). Stop here instead of recursing forever.
        if (!canSplit) {
            return;
        }

        int best = new OutPut().output(gains, feature);
        for (Object value : array.get(best)) {
            List<Object[]> subset = partition(list, best, value);
            // Purity is judged per subset. Empty subsets, subsets that did not shrink, and
            // single-label subsets are leaves.
            if (subset.isEmpty() || subset.size() == list.size() || isPure(subset, labelIndex)) {
                continue;
            }
            id3(subset, feature, array);
        }
    }

    /** Returns the rows whose {@code column} equals {@code value}, in their original order. */
    private static List<Object[]> partition(List<Object[]> rows, int column, Object value) {
        List<Object[]> subset = new ArrayList<>();
        for (Object[] row : rows) {
            if (same(row[column], value)) {
                subset.add(row);
            }
        }
        return subset;
    }

    /** Returns true when every row in the non-empty {@code rows} has the same label. */
    private static boolean isPure(List<Object[]> rows, int labelIndex) {
        Object first = rows.get(0)[labelIndex];
        for (Object[] row : rows) {
            if (!same(row[labelIndex], first)) {
                return false;
            }
        }
        return true;
    }

    /** Null-safe value equality; identical references short-circuit. */
    private static boolean same(Object a, Object b) {
        return a == b || (a != null && a.equals(b));
    }
}
//测试
import java.util.ArrayList;
import java.util.List;
import static java.lang.Float.NaN;
public class test {
public static void main(String[] args) {
String[] feature = {"年龄","有工作","有自己的房子","信贷情况","类别"};
Object[] age ={"青年","中年","老年"};
Object[] work = {'是','否'};
Object[] house = {'是','否'};
Object[] loan = {"一般",'好',"非常好"};
Object[] category = {'是','否'};
List<Object[]> array = new ArrayList<Object[]>();
array.add(age);
array.add(work);
array.add(house);
array.add(loan);
array.add(category);
Object[] o1 = {age[0],work[1],house[1],loan[0],category[1]};
Object[] o2 = {age[0],work[1],house[1],loan[1],category[1]};
Object[] o3 = {age[0],work[0],house[1],loan[1],category[0]};
Object[] o4 = {age[0],work[0],house[0],loan[0],category[0]};
Object[] o5 = {age[0],work[1],house[1],loan[0],category[1]};
Object[] o6 = {age[1],work[1],house[1],loan[0],category[1]};
Object[] o7 = {age[1],work[1],house[1],loan[1],category[1]};
Object[] o8 = {age[1],work[0],house[0],loan[1],category[0]};
Object[] o9 = {age[1],work[1],house[0],loan[2],category[0]};
Object[] o10 = {age[1],work[1],house[0],loan[2],category[0]};
Object[] o11 = {age[2],work[1],house[0],loan[2],category[0]};
Object[] o12 = {age[2],work[1],house[0],loan[1],category[0]};
Object[] o13 = {age[2],work[0],house[1],loan[1],category[0]};
Object[] o14 = {age[2],work[0],house[1],loan[2],category[0]};
Object[] o15 = {age[2],work[1],house[1],loan[0],category[1]};
List<Object[]> list = new ArrayList<Object[]>();
list.add(o1);
list.add(o2);
list.add(o3);
list.add(o4);
list.add(o5);
list.add(o6);
list.add(o7);
list.add(o8);
list.add(o9);
list.add(o10);
list.add(o11);
list.add(o12);
list.add(o13);
list.add(o14);
list.add(o15);
ID3 id3 = new ID3();
id3.id3(list,feature,array);
}
}