gbdt java_java实现gbdt

DATA类

import java.io.File;

import java.io.FileNotFoundException;

import java.util.ArrayList;

import java.util.Scanner;

/**
 * Loads a comma-separated training file into memory as rows of raw string fields.
 *
 * <p>Every non-blank line becomes one row. The line is trimmed as a whole and then split on
 * {@code ","}; individual fields are stored exactly as they appear between the commas, so
 * callers that need whitespace-free values must trim them.
 */
public class Data {

    /** Location read by the no-argument constructor. */
    private static final String DEFAULT_DATA_PATH =
            "D://javajavajava//dbdt//src//script//data//adult.data.csv";

    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();

    /**
     * Returns the rows that were read from the file.
     *
     * @return the internal (mutable) list of rows; empty when the file could not be opened
     */
    public ArrayList<ArrayList<String>> getTrainData() {
        return this.trainData;
    }

    /** Reads the data set from the built-in default location. */
    public Data() {
        this(DEFAULT_DATA_PATH);
    }

    /**
     * Reads the data set from {@code dataPath}.
     *
     * <p>A missing file is reported on standard error and leaves the data set empty rather than
     * throwing, which keeps the original best-effort behaviour.
     *
     * @param dataPath path of the comma-separated file to read
     */
    public Data(String dataPath) {
        // try-with-resources guarantees the Scanner (and its file handle) is closed.
        try (Scanner in = new Scanner(new File(dataPath))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A blank line would otherwise turn into a bogus one-field row.
                    continue;
                }
                String[] strs = line.split(",");
                if (strs.length == 0) {
                    // A line made only of commas splits into nothing; there is no row to keep.
                    continue;
                }
                ArrayList<String> tmp = new ArrayList<>(strs.length);
                for (String field : strs) {
                    tmp.add(field);
                }
                this.trainData.add(tmp);
            }
        } catch (FileNotFoundException e) {
            System.err.println("Data: training file not found: " + dataPath);
        }
    }

    public static void main(String[] args) {
        Data d = new Data();
        System.err.println(d.getTrainData().size());
    }

}

TREE类

import java.util.ArrayList;

import java.util.HashSet;

import java.util.Iterator;

import java.util.Random;

import java.util.spi.TimeZoneNameProvider;

/**
 * A binary regression tree used as the weak learner of the gradient-boosting model.
 *
 * <p>An internal node stores the index of the attribute it splits on and the split value.
 * Numeric attributes send a row to the left child when its value is {@code <=} the split value;
 * categorical attributes send a row to the left child when its trimmed value equals the split
 * value. A leaf stores the value it predicts and the row indices that reached it.
 *
 * <p>Throughout this class {@code subIdx} holds row indices into {@code trainData}, and
 * {@code target} is looked up with those same indices.
 */
public class Tree {

    /** Shared generator used to sample candidate split values of numeric attributes. */
    private static final Random RANDOM = new Random();

    // Children are created by constructTree. Initialising them with "new Tree()" here would make
    // every constructor call recurse without end.
    private Tree leftTree;

    private Tree rightTree;

    private double loss = -1;

    private int attributeSplit = 0;

    private String attributeSplitType = "";

    boolean isLeaf;

    double leafValue;

    private ArrayList<Integer> leafNodeSet = new ArrayList<>();

    /**
     * Collects the distinct values of one attribute over all rows of {@code trainData}.
     *
     * @param trainData rows of raw string fields
     * @param idx       column to inspect
     * @return the distinct trimmed values of column {@code idx}, in order of first appearance
     */
    public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData, int idx) {
        HashSet<String> seen = new HashSet<>();
        ArrayList<String> ans = new ArrayList<>();
        for (int i = 0; i < trainData.size(); i++) {
            // Trim so that the values can be matched against the trimmed cells used when splitting.
            String value = trainData.get(i).get(idx).trim();
            if (seen.add(value)) {
                ans.add(value);
            }
        }
        return ans;
    }

    /**
     * Compares two integers given as (possibly padded) decimal strings.
     *
     * @return {@code true} when {@code str1 <= str2}
     * @throws NumberFormatException if either argument is not a valid {@code int}
     */
    public boolean myCmpLess(String str1, String str2) {
        return Integer.parseInt(str1.trim()) <= Integer.parseInt(str2.trim());
    }

    /**
     * Measures the spread of {@code values} around their mean.
     *
     * @param values the target values of one side of a split
     * @return the square root of the sum of squared deviations from the mean; {@code 0} for an
     *         empty list
     */
    public double computeLoss(ArrayList<Double> values) {
        if (values.isEmpty()) {
            return 0;
        }
        double sum = 0;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.size();
        double squared = 0;
        for (double value : values) {
            squared += (value - mean) * (value - mean);
        }
        return Math.sqrt(squared);
    }

    /**
     * Computes the value of a leaf from the residuals of the rows that reached it, using the
     * K-class update {@code (K-1)/K * sum(r) / sum(|r| * (1 - |r|))}.
     *
     * @param K      number of classes
     * @param subIdx row indices that belong to the leaf
     * @param target residuals, looked up by the indices in {@code subIdx}
     * @return the leaf value, or {@code 0} when the denominator is zero (for example an empty leaf)
     */
    public double getPredictValue(int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        double sum = 0;
        double denominator = 0;
        for (int row : subIdx) {
            double r = target.get(row);
            double magnitude = Math.abs(r);
            sum += r;
            denominator += magnitude * (1 - magnitude);
        }
        if (denominator == 0) {
            return 0;
        }
        // (K - 1.0): with int operands (K-1)/K truncates to 0 and wipes out every leaf value.
        return (K - 1.0) / K * sum / denominator;
    }

    /** Returns the value stored in {@code root}; meaningful when {@code root} is a leaf. */
    public double getPredictValue(Tree root) {
        return root.leafValue;
    }

    /**
     * Routes {@code instance} down the tree rooted at {@code root} and returns the value of the
     * leaf it ends in.
     *
     * @param root     tree to evaluate
     * @param instance raw attribute values of one row
     * @param isDigit  per attribute, whether it is numeric
     */
    public double getPredictValue(Tree root, ArrayList<String> instance, Boolean isDigit[]) {
        Tree node = root;
        while (!node.isLeaf) {
            String value = instance.get(node.attributeSplit).trim();
            boolean goesLeft = isDigit[node.attributeSplit]
                    ? myCmpLess(value, node.attributeSplitType)
                    : value.equals(node.attributeSplitType);
            node = goesLeft ? node.leftTree : node.rightTree;
        }
        return node.leafValue;
    }

    /**
     * Recursively grows a tree over the rows listed in {@code subIdx}.
     *
     * <p>At every internal node each attribute is tried: categorical attributes use every distinct
     * value as a candidate, numeric attributes use at most {@code splitPoints} randomly chosen
     * distinct values. The candidate with the smallest
     * {@code computeLoss(left) + computeLoss(right)} that leaves both sides non-empty is kept.
     * A node becomes a leaf when {@code depth} reaches {@code maxDepth[0]}, when it holds fewer
     * than two rows, or when no candidate separates its rows.
     *
     * @param leafNodes   output: the row indices of every leaf created, in creation order
     * @param leafValues  output: the value of every leaf created, parallel to {@code leafNodes}
     * @param K           number of classes, forwarded to the leaf-value formula
     * @param splitPoints maximum number of candidate split values per numeric attribute
     * @param isDigit     per attribute, whether it is numeric
     * @param subIdx      row indices (into {@code trainData}) handled by this node
     * @param trainData   rows of raw attribute values
     * @param target      regression targets, looked up by the indices in {@code subIdx}
     * @param maxDepth    single-element array holding the maximum depth
     * @param depth       depth of the node being built (root is 0)
     * @return the root of the constructed (sub)tree
     */
    public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K, int splitPoints, Boolean isDigit[], ArrayList<Integer> subIdx, ArrayList<ArrayList<String>> trainData, ArrayList<Double> target, int maxDepth[], int depth) {
        if (depth >= maxDepth[0] || subIdx.size() < 2) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        int dim = Math.min(trainData.get(0).size(), isDigit.length);
        double bestLoss = -1;
        ArrayList<Integer> bestLeft = new ArrayList<>();
        ArrayList<Integer> bestRight = new ArrayList<>();
        int bestAttribute = 0;
        String bestValue = "";

        for (int i = 0; i < dim; i++) {
            ArrayList<String> candidates = getAttributeSet(trainData, i);
            if (isDigit[i]) {
                candidates = sampleCandidates(candidates, splitPoints);
            }
            for (String candidate : candidates) {
                // Fresh partitions per candidate; reusing them would mix rows of earlier splits.
                ArrayList<Integer> left = new ArrayList<>();
                ArrayList<Integer> right = new ArrayList<>();
                ArrayList<Double> leftTarget = new ArrayList<>();
                ArrayList<Double> rightTarget = new ArrayList<>();
                for (int row : subIdx) {
                    String value = trainData.get(row).get(i);
                    boolean goesLeft = isDigit[i]
                            ? myCmpLess(value, candidate)
                            : value.trim().equals(candidate);
                    if (goesLeft) {
                        left.add(row);
                        leftTarget.add(target.get(row));
                    } else {
                        right.add(row);
                        rightTarget.add(target.get(row));
                    }
                }
                if (left.isEmpty() || right.isEmpty()) {
                    // The candidate does not separate anything at this node.
                    continue;
                }
                double lossTmp = computeLoss(leftTarget) + computeLoss(rightTarget);
                if (bestLoss < 0 || lossTmp < bestLoss) {
                    bestLoss = lossTmp;
                    bestLeft = left;
                    bestRight = right;
                    bestAttribute = i;
                    bestValue = candidate;
                }
            }
        }

        if (bestLoss < 0) {
            // All rows look identical to every candidate split.
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        Tree node = new Tree();
        node.attributeSplit = bestAttribute;
        node.attributeSplitType = bestValue;
        node.loss = bestLoss;
        node.isLeaf = false;
        node.leftTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestLeft, trainData, target, maxDepth, depth + 1);
        node.rightTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestRight, trainData, target, maxDepth, depth + 1);
        return node;
    }

    /** Creates a leaf for {@code subIdx} and records it in {@code leafNodes}/{@code leafValues}. */
    private Tree makeLeaf(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        Tree leaf = new Tree();
        leaf.isLeaf = true;
        leaf.leafValue = getPredictValue(K, subIdx, target);
        leaf.leafNodeSet.addAll(subIdx);
        leafNodes.add(subIdx);
        leafValues.add(leaf.leafValue);
        return leaf;
    }

    /** Picks up to {@code count} distinct entries of {@code values} uniformly at random. */
    private ArrayList<String> sampleCandidates(ArrayList<String> values, int count) {
        if (count >= values.size()) {
            return values;
        }
        ArrayList<String> pool = new ArrayList<>(values);
        ArrayList<String> picked = new ArrayList<>();
        // Partial Fisher-Yates shuffle: position i receives a random not-yet-picked element.
        for (int i = 0; i < count; i++) {
            int j = i + RANDOM.nextInt(pool.size() - i);
            String chosen = pool.get(j);
            pool.set(j, pool.get(i));
            pool.set(i, chosen);
            picked.add(chosen);
        }
        return picked;
    }

    public static void main(String[] args) {
        Tree aTree = new Tree();
    }

}

GBDT类

import java.rmi.server.SkeletonNotFoundException;

import java.util.ArrayList;

import java.util.HashSet;

import java.util.Iterator;

import java.util.Map;

import java.util.Map.Entry;

import java.util.Random;

import java.util.Set;

public class GBDT {

private ArrayList> datas=new ArrayList>();

private ArrayList labelSets=new ArrayList<>();

private ArrayList> F=new ArrayList>();

private ArrayList> residual=new ArrayList>();

private ArrayList> trainData=new ArrayList>();

private ArrayList labelTrainData=new ArrayList();

private int K;

private Boolean isDigit[];

private int dim;

private int n;

private double learningRate;

private ArrayList> trees=new ArrayList>(); //存放所有的树

private int max_iter;

private double sampleRate;

private int maxDepth;

private int splitPoints;

public void computeResidual(ArrayList subId)

{

for(int i=0;i

{

int idx=subId.get(i);

int y=0;

if(this.labelTrainData.get(idx)==-1) y=0;

else y=1;

double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));

double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;

this.residual.get(idx).set(0, y-p1);

this.residual.get(idx).set(1, y-p2);

}

}

public ArrayList myrandom(int maxNum,int num)

{

ArrayList ans=new ArrayList<>();

Set mySet=new HashSet<>();

while(mySet.size()

{

Random r=new Random();

int tmp=r.nextInt(maxNum);

mySet.add(tmp);

}

Iterator it=mySet.iterator();

while(it.hasNext())

{

ans.add(it.next());

}

return ans;

}

public GBDT()

{

this.max_iter=50;

this.sampleRate=0.8;

this.K=2;//2分类问题

this.maxDepth=6;

this.splitPoints=3;

this.learningRate=0.01;

getData();

}

public void train()

{

for(int i=0;i

{

ArrayList subSet=new ArrayList<>();

int numSubset=(int)(this.n*this.sampleRate);

subSet=myrandom(this.n,numSubset);

computeResidual(subSet);

ArrayList target=new ArrayList<>();

ArrayList tmpTree=new ArrayList<>();

int maxdepths[]={this.maxDepth};

for(int j=0;j

{

target.clear();

for(int k=0;k

{

target.add(residual.get(subSet.get(k)).get(j));

}

ArrayList> leafNodes=new ArrayList>();

ArrayList leafValues=new ArrayList<>();

Tree treeSub=new Tree();

Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);

tmpTree.add(iterTree);

updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);

}

trees.add(tmpTree);

}

}

public void updateFvalue(Boolean isDigit[], ArrayList subIdx,ArrayList> leafNodes,ArrayList leafValues,int label,Tree root)

{

ArrayList remainIdx=new ArrayList<>();

int arr[]=new int[this.n];

for(int i=0;i

arr[i]=i;

for(int i=0;i

{

arr[subIdx.get(i)]=-1;

}

//求出不是用来训练树的余下集合

for(int i=0;i

{

if(arr[i]!=-1)

remainIdx.add(i);

}

for(int i=0;i

{

for(int j=0;j

{

this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));

}

}

for(int i=0;i

{

double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);

this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);

}

}

public boolean checkDigit(String str) {

for(int i=0;i

{

if(!(str.charAt(i)>='0'&&str.charAt(i)<='9'))

{

return false;

}

}

return true;

}

public void getData() {

Data d =new Data();

this.datas=d.getTrainData();

this.dim=this.datas.get(0).size()-1;

this.isDigit=new Boolean[this.dim];

//遍历所有样本,去掉中间含有不是正常的数据

for(int i=0;i

labelSets.add(this.datas.get(0).get(i));

//保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串

for(int i=0;i

{

if(checkDigit(this.datas.get(0).get(i)))

this.isDigit[i]=true;

else this.isDigit[i]=false;

}

//如果字符串==?说明是异常数据,这里做数据的清理

for(int i=1;i

{

ArrayList tmp=new ArrayList<>();

boolean flag=true;

for(int j=0;j

{

if(datas.get(i).get(j).trim().equals("?"))

{

flag=false;

break;

}

}

if(!flag) continue;

if(datas.get(i).get(this.dim).trim().equals("?")) continue;

trainData.add(tmp);

if(datas.get(i).get(this.dim).trim().equals("<=50K"))

labelTrainData.add(-1);

else

labelTrainData.add(1);

}

this.n=this.labelTrainData.size();

for(int i=0;i

labelSets.add(this.datas.get(0).get(i));

//初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了

for(int i=0;i

{

ArrayList arrTmp=new ArrayList();

for(int j=0;j<2;j++)

{

arrTmp.add(0.0);

}

this.F.add(arrTmp);

this.residual.add(arrTmp);

}

}

public static void main(String[] args) {

GBDT dGbdt=new GBDT();

dGbdt.getData();

System.err.println(dGbdt.n);

}

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值