初学大数据算法,python不熟故用java编写,编码不规范,欢迎评论交流
直接说一下代码
节点类
一个节点里面包含
no,yes 节点矩阵里结果的数量,比如这个用户买就是yes不买就是no
JudgeConditions 决策树建成后代入数据时使用的判断条件:比如它的值是0,那么分裂属性值为0的数据就会经过这个节点
classificationNum 这个节点分裂属性类别个数
target 分裂属性所在的列号(即数据矩阵中的列下标)
isLeaf 节点是否是叶子节点
/**
 * A single node of the decision tree.
 *
 * <p>Plain mutable data holder: every field is package-visible and is read and
 * written directly by the tree-building code.
 */
class Node {

    /** Number of rows at this node whose outcome is "no". */
    int no;

    /** Number of rows at this node whose outcome is "yes". */
    int yes;

    /**
     * Value a record must have (in the parent's split attribute) in order to
     * pass through this node when the finished tree is evaluated.
     */
    int JudgeConditions;

    /** Number of distinct categories of this node's split attribute. */
    int classificationNum;

    /** Column of the attribute this node splits on. */
    int target;

    /** Whether this node is a leaf. */
    boolean isLeaf;

    /** Marks which attributes have already been used as split attributes. */
    int[] Used;

    /** The distinct values of this node's split attribute, e.g. 0, 1, 2. */
    double[] category;

    /** How many rows carry each of the values in {@link #category}. */
    double[] categoryNum;

    /** Share of each value in {@link #category}; used for the entropy computation. */
    double[] categoryPercentage;

    /** Child nodes of this node. */
    ArrayList<Node> sonNode;

    /** Row numbers of the original data set that belong to this node. */
    ArrayList<Integer> rowNum;
}
/**
 * Randomly picks {@code testNum} distinct rows out of the first {@code lineNum}
 * rows of {@code data} and copies them into {@code testArray}.
 *
 * <p>A row is eligible only while {@code used[row] == 0}; once a row has been
 * copied its entry in {@code used} is set to 1, so no row is selected twice and
 * the caller can tell afterwards which rows were taken.
 *
 * @param lineNum   number of leading rows of {@code data} to sample from
 * @param random    source of randomness
 * @param testNum   number of rows to pick
 * @param used      per-row flags (0 = still available); updated in place
 * @param testArray destination; rows {@code 0..testNum-1} are overwritten
 * @param data      source rows; {@code data[0].length} columns are copied per row
 * @throws IllegalArgumentException if fewer than {@code testNum} rows are still available
 */
private static void getTestData(int lineNum, Random random, int testNum, int[] used, String[][] testArray, String[][] data) {
    // Without this guard the rejection loop below would never terminate
    // when more rows are requested than remain unused.
    int available = 0;
    for (int row = 0; row < lineNum; row++) {
        if (used[row] == 0) {
            available++;
        }
    }
    if (testNum > available) {
        throw new IllegalArgumentException(
                "requested " + testNum + " rows but only " + available + " unused rows are available");
    }

    int filled = 0;
    while (filled < testNum) {
        int num = Math.abs(random.nextInt() % lineNum);
        if (used[num] != 0) {
            continue; // already taken, draw again
        }
        System.arraycopy(data[num], 0, testArray[filled], 0, data[0].length);
        // Mark the row that was actually sampled (the original marked index
        // `filled`, which let the same row be drawn repeatedly).
        used[num] = 1;
        filled++;
    }
}
//获取一个属性有多少分类
/**
 * Collects, for each of the first 17 columns of {@code data}, the distinct
 * values that occur in that column and appends them to {@code judge}.
 *
 * <p>Exactly one {@code String[]} is appended per column, in column order:
 * the distinct values in order of first appearance when {@code used[col] == 0},
 * or an empty array when the column is flagged as used.
 *
 * <p>Values are compared with {@code equals}. The original used
 * {@code String.matches}, which treats the stored value as a regular
 * expression and so merged e.g. "abc" into "a.c"; it also overflowed a fixed
 * 10-slot buffer on columns with more than 10 distinct values.
 *
 * @param data  rows of attribute values; every row needs at least 17 columns
 * @param used  per-column flags (0 = column should be analysed); at least 17 entries
 * @param judge output list; receives one {@code String[]} per column
 */
private static void getClassification(String[][] data, int[] used, ArrayList judge) {
    // NOTE(review): 17 is carried over unchanged from the original code;
    // presumably it is the column count of the data set - confirm with the caller.
    final int columnCount = 17;

    for (int col = 0; col < columnCount; col++) {
        if (used[col] != 0) {
            judge.add(new String[0]);
            continue;
        }
        ArrayList<String> distinct = new ArrayList<>();
        for (String[] row : data) {
            String value = row[col];
            if (!distinct.contains(value)) {
                distinct.add(value);
            }
        }
        judge.add(distinct.toArray(new String[0]));
    }
}
//获取建立决策树的数据
/**
 * Fills every still-empty row of {@code testData} with a randomly chosen row
 * of {@code data} and then writes {@code testData} to disk.
 *
 * <p>A row counts as empty when its first cell is {@code null}; rows that are
 * already populated are left untouched (the original retried the same index
 * forever in that case). Selected rows are shared with {@code data}, not copied,
 * and the same source row may be chosen more than once.
 *
 * <p>NOTE(review): rows are drawn from at most the first 1000 rows of
 * {@code data}, as in the original; confirm whether that limit is intended.
 *
 * @param a        source of randomness
 * @param data     source rows to draw from
 * @param testData destination; each element must be a non-null row array
 */
private static void getTestData(Random a, String[][] data, String[][] testData) {
    // Cap at data.length so a data set smaller than 1000 rows cannot be indexed out of range.
    int bound = Math.min(1000, data.length);
    for (int i = 0; i < testData.length; i++) {
        if (testData[i][0] != null) {
            continue; // already populated
        }
        testData[i] = data[Math.abs(a.nextInt() % bound)];
    }
    writeFileByLines("/Users/zhinian/Downloads/sonArray.txt", testData);
}
//连续性数据的处理
/**
 * Chooses the split position for a continuous attribute.
 *
 * <p>Sorts {@code data} via {@code quikeSort} using column {@code target}, then
 * evaluates {@code entropySeparate} for every candidate position
 * {@code 1 .. node.rowNum.size() - 1} and returns the position with the
 * strictly largest score (the earliest one on ties).
 *
 * <p>The running maximum starts at negative infinity. The original started at
 * {@code Double.MIN_VALUE}, which is the smallest <em>positive</em> double, so
 * candidates scoring 0 or less could never be selected.
 *
 * @param data   data matrix; sorted in place by {@code quikeSort}
 * @param target column of the continuous attribute
 * @param node   node whose {@code rowNum} size bounds the candidate positions
 * @return the best split position, or -1 if there is no candidate
 */
public static int continuousDataDealWith(int[][] data, int target, Node node) {
    quikeSort(data, 0, data.length - 1, target);

    double bestScore = Double.NEGATIVE_INFINITY;
    int bestRow = -1;
    for (int i = 1; i < node.rowNum.size(); i++) {
        double score = entropySeparate(data, node, target, i);
        if (score > bestScore) {
            bestScore = score;
            bestRow = i;
        }
    }
    return bestRow;
}
//计算叶子节点误差//悲观剪枝里计算叶子结点偏差总值 每个加零点五
/**
 * Sums the error estimate over all leaves of the subtree rooted at {@code head}.
 *
 * <p>The subtree is walked breadth-first. Each leaf contributes
 * {@code min(yes, no) + 0.5}: the smaller of its two outcome counts plus the
 * 0.5 per-leaf correction used for pessimistic pruning. Inner nodes contribute
 * nothing themselves.
 *
 * @param head root of the subtree to evaluate
 * @return the accumulated error over all leaves
 */
public static double errorSum(Node head) {
    ArrayList<Node> pending = new ArrayList<>();
    pending.add(head);

    double total = 0;
    // The list doubles as a FIFO work queue: `cursor` is the read position and
    // children are appended at the end while iterating.
    for (int cursor = 0; cursor < pending.size(); cursor++) {
        Node current = pending.get(cursor);
        if (!current.isLeaf) {
            pending.addAll(current.sonNode);
            continue;
        }
        total += Math.min(current.yes, current.no) + 0.5;
    }
    return total;
}
//算法本体
public static Node D13(Node head, int[][] data) {
int target = new C45().nextEntropy(data, head);
if(target==-1)
return head;head.target = target;
int nowIndex = 0;
int aa[] = new int[head.classificationNum];
for (int j = 0; j < head.rowNum.size(); j++) {
boolean flag = false;
for (int k = 0; k < nowIndex; k++) {
if (aa[k] == data[head.rowNum.get(j)][target]) {
head.sonNode.get(k).rowNum.add(head.rowNum.get(j));
if (data[head.rowNum.get(j)][data[0].length - 1] == 0)
head.sonNode.get(k).no++;
else
head.sonNode.get(k).yes++;
flag = true;
break;
}
}
if (!flag) {
head.sonNode.add(InitNode(head,data[0].length));
head.sonNode.get(nowIndex).JudgeConditions=data[head.rowNum.get(j)][target];
if(target==0)
// System.out.println(data[head.rowNum.get(j)][target]);
if (data[head.rowNum.get(j)][data[0].length - 1] == 0)
head.sonNode.get(nowIndex).no++;
else
head.sonNode.get(nowIndex).y