// ===== Data class =====
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;
/**
 * Loads a comma-separated training file into memory as rows of raw string fields.
 *
 * <p>Each non-blank input line becomes one row. The line is trimmed and split on {@code ","};
 * the individual fields are stored exactly as split (they may still carry leading or trailing
 * whitespace), so consumers must trim fields themselves before comparing them.
 *
 * <p>Loading is best-effort: if the file cannot be found, a message is written to
 * {@code System.err} and the row list is left empty.
 */
public class Data {
    /** Hard-coded location read by the no-argument constructor. */
    private static final String DEFAULT_DATA_PATH =
            "D://javajavajava//dbdt//src//script//data//adult.data.csv";

    /** One entry per non-blank input line; each entry holds that line's comma-separated fields. */
    private final ArrayList<ArrayList<String>> trainData = new ArrayList<>();

    /**
     * Returns the loaded rows.
     *
     * @return the internal (mutable) list of rows; empty when the file was missing or empty
     */
    public ArrayList<ArrayList<String>> getTrainData() {
        return this.trainData;
    }

    /** Loads the rows from the hard-coded default path. */
    public Data() {
        this(DEFAULT_DATA_PATH);
    }

    /**
     * Loads the rows from the given file.
     *
     * @param dataPath path of the comma-separated file to read
     */
    public Data(String dataPath) {
        // try-with-resources closes the underlying file handle, which the Scanner otherwise leaks.
        try (Scanner in = new Scanner(new File(dataPath))) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                if (line.isEmpty()) {
                    // A blank line would otherwise produce a one-field row and break
                    // column-indexed access by consumers.
                    continue;
                }
                String[] strs = line.split(",");
                ArrayList<String> tmp = new ArrayList<>(strs.length);
                for (int i = 0; i < strs.length; i++) {
                    tmp.add(strs[i]);
                }
                this.trainData.add(tmp);
            }
        } catch (FileNotFoundException e) {
            // Deliberately best-effort: report the problem and leave the row list empty.
            System.err.println("Data: training file not found: " + dataPath);
            e.printStackTrace();
        }
    }

    /**
     * Smoke-test entry point: loads the default file.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        new Data();
    }
}
// ===== Tree class =====
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;
/**
 * A node of a binary regression tree fitted to residuals.
 *
 * <p>An internal node stores the column it splits on ({@code attributeSplit}) and the split
 * value ({@code attributeSplitType}). For a numeric column a sample goes left when its value is
 * {@code <=} the split value; for a categorical column it goes left when its trimmed value
 * equals the split value. A leaf node stores a constant prediction in {@code leafValue}.
 *
 * <p>Throughout this class, {@code target} is indexed by sample id: {@code target.get(s)} is the
 * residual of row {@code s} of {@code trainData}, and index lists such as {@code subIdx} hold
 * such sample ids.
 */
public class Tree {
    /** Shared generator used to sample candidate split values for numeric columns. */
    private static final Random RANDOM = new Random();

    // Children start out as null. Initialising them with "new Tree()" would recurse without
    // end and overflow the stack as soon as any Tree is constructed.
    private Tree leftTree;
    private Tree rightTree;
    /** Loss of the chosen split; -1 when no split has been recorded. */
    private double loss = -1;
    /** Column index this node splits on (internal nodes only). */
    private int attributeSplit = 0;
    /** Split value for {@code attributeSplit} (internal nodes only). */
    private String attributeSplitType = "";
    /** True when this node is a leaf. */
    boolean isLeaf;
    /** Prediction stored in a leaf. */
    double leafValue;
    /** Sample ids that ended up in this leaf. */
    private ArrayList<Integer> leafNodeSet = new ArrayList<>();

    /**
     * Collects the distinct values found in one column.
     *
     * @param trainData rows of string fields
     * @param idx column index to inspect
     * @return the distinct trimmed values of column {@code idx}, in no particular order
     */
    public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData, int idx) {
        HashSet<String> distinct = new HashSet<>();
        for (ArrayList<String> row : trainData) {
            // Trim so that the values agree with the trimmed comparisons done when splitting
            // and predicting; untrimmed values would never match there.
            distinct.add(row.get(idx).trim());
        }
        return new ArrayList<>(distinct);
    }

    /**
     * Compares two integer strings.
     *
     * @param str1 left operand, surrounding whitespace allowed
     * @param str2 right operand, surrounding whitespace allowed
     * @return true when {@code str1 <= str2} as integers
     * @throws NumberFormatException if either string is not a valid {@code int}
     */
    public boolean myCmpLess(String str1, String str2) {
        return Integer.parseInt(str1.trim()) <= Integer.parseInt(str2.trim());
    }

    /**
     * Measures the spread of a set of values.
     *
     * @param values values to measure
     * @return the square root of the sum of squared deviations from the mean; 0 for an empty list
     */
    public double computeLoss(ArrayList<Double> values) {
        if (values.isEmpty()) {
            return 0;
        }
        double sum = 0;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.size();
        double squared = 0;
        for (double value : values) {
            double deviation = value - mean;
            squared += deviation * deviation;
        }
        return Math.sqrt(squared);
    }

    /**
     * Computes the leaf value for a set of samples as
     * {@code (K-1)/K * sum(r) / sum(|r| * (1 - |r|))}, where {@code r} ranges over the residuals
     * of the samples in {@code subIdx}.
     *
     * @param K number of classes
     * @param subIdx sample ids belonging to the leaf
     * @param target residuals indexed by sample id
     * @return the leaf value, or 0 when the denominator is 0 (including an empty {@code subIdx})
     */
    public double getPredictValue(int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        double residualSum = 0;
        double denominator = 0;
        for (Integer sample : subIdx) {
            double r = target.get(sample);
            residualSum += r;
            double magnitude = Math.abs(r);
            denominator += magnitude * (1 - magnitude);
        }
        if (denominator == 0) {
            return 0;
        }
        // (K - 1.0) forces floating-point division; the integer form (K-1)/K is always 0.
        return (K - 1.0) / K * residualSum / denominator;
    }

    /**
     * Returns the value stored directly in a node.
     *
     * @param root node to read
     * @return {@code root.leafValue}
     */
    public double getPredictValue(Tree root) {
        return root.leafValue;
    }

    /**
     * Routes one instance down the tree and returns the value of the leaf it reaches.
     *
     * @param root root of the (sub)tree to evaluate
     * @param instance field values of the instance, indexed by column
     * @param isDigit per-column flag: true when the column is compared numerically
     * @return the value of the reached leaf
     */
    public double getPredictValue(Tree root, ArrayList<String> instance, Boolean isDigit[]) {
        if (root.isLeaf) {
            return root.leafValue;
        }
        String value = instance.get(root.attributeSplit).trim();
        boolean goLeft = isDigit[root.attributeSplit]
                ? myCmpLess(value, root.attributeSplitType)
                : value.equals(root.attributeSplitType);
        return getPredictValue(goLeft ? root.leftTree : root.rightTree, instance, isDigit);
    }

    /**
     * Recursively builds a regression tree over the samples in {@code subIdx}.
     *
     * <p>At each node every column is tried: categorical columns try every distinct value,
     * numeric columns try at most {@code splitPoints} values (sampled at random when there are
     * more distinct values than that). The split with the smallest
     * {@code computeLoss(left) + computeLoss(right)} is kept. A node becomes a leaf when the
     * depth limit is reached, when it holds fewer than two samples, or when no candidate
     * separates the samples into two non-empty groups.
     *
     * @param leafNodes output: the sample-id list of every leaf created, in creation order
     * @param leafValues output: the value of every leaf created, parallel to {@code leafNodes}
     * @param K number of classes, forwarded to the leaf-value formula
     * @param splitPoints maximum number of candidate split values tried per numeric column
     * @param isDigit per-column flag: true when the column is compared numerically
     * @param subIdx sample ids that reach this node
     * @param trainData all rows; row {@code s} belongs to sample id {@code s}
     * @param target residuals indexed by sample id
     * @param maxDepth one-element array holding the depth limit
     * @param depth depth of the node being built (0 for the root)
     * @return the root of the built (sub)tree
     */
    public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K,
            int splitPoints, Boolean isDigit[], ArrayList<Integer> subIdx,
            ArrayList<ArrayList<String>> trainData, ArrayList<Double> target, int maxDepth[], int depth) {
        if (depth >= maxDepth[0] || subIdx.size() < 2 || trainData.isEmpty()) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        // Only columns that have a type flag can be split on.
        int dim = Math.min(trainData.get(0).size(), isDigit.length);
        double bestLoss = -1;
        ArrayList<Integer> bestLeft = null;
        ArrayList<Integer> bestRight = null;
        int bestAttribute = 0;
        String bestValue = "";

        for (int i = 0; i < dim; i++) {
            ArrayList<String> candidates = getAttributeSet(trainData, i);
            if (isDigit[i]) {
                candidates = sampleCandidates(candidates, splitPoints);
            }
            for (String candidate : candidates) {
                // Fresh partitions per candidate; reusing them would accumulate samples
                // from earlier candidates.
                ArrayList<Integer> left = new ArrayList<>();
                ArrayList<Integer> right = new ArrayList<>();
                ArrayList<Double> leftTarget = new ArrayList<>();
                ArrayList<Double> rightTarget = new ArrayList<>();
                for (Integer sample : subIdx) {
                    String value = trainData.get(sample).get(i);
                    boolean goLeft = isDigit[i]
                            ? myCmpLess(value, candidate)
                            : value.trim().equals(candidate);
                    if (goLeft) {
                        left.add(sample);
                        leftTarget.add(target.get(sample));
                    } else {
                        right.add(sample);
                        rightTarget.add(target.get(sample));
                    }
                }
                if (left.isEmpty() || right.isEmpty()) {
                    continue; // does not separate anything
                }
                double lossTmp = computeLoss(leftTarget) + computeLoss(rightTarget);
                if (bestLoss < 0 || lossTmp < bestLoss) {
                    bestLoss = lossTmp;
                    bestLeft = left;
                    bestRight = right;
                    bestAttribute = i;
                    bestValue = candidate;
                }
            }
        }

        if (bestLeft == null) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        Tree node = new Tree();
        node.attributeSplit = bestAttribute;
        node.attributeSplitType = bestValue;
        node.loss = bestLoss;
        node.isLeaf = false;
        node.leftTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestLeft,
                trainData, target, maxDepth, depth + 1);
        node.rightTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestRight,
                trainData, target, maxDepth, depth + 1);
        return node;
    }

    /** Picks the split values to try for a numeric column: all of them, or a random sample. */
    private ArrayList<String> sampleCandidates(ArrayList<String> distinct, int splitPoints) {
        if (distinct.size() <= splitPoints) {
            return distinct;
        }
        ArrayList<String> sampled = new ArrayList<>();
        while (sampled.size() < splitPoints) {
            sampled.add(distinct.get(RANDOM.nextInt(distinct.size())));
        }
        return sampled;
    }

    /** Creates a leaf for {@code subIdx} and records it in the two output lists. */
    private Tree makeLeaf(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K,
            ArrayList<Integer> subIdx, ArrayList<Double> target) {
        Tree leaf = new Tree();
        leaf.isLeaf = true;
        leaf.leafValue = getPredictValue(K, subIdx, target);
        leaf.leafNodeSet.addAll(subIdx);
        leafNodes.add(subIdx);
        leafValues.add(leaf.leafValue);
        return leaf;
    }

    /**
     * Smoke-test entry point: constructs an empty node.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        new Tree();
    }
}
// ===== GBDT class =====
import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
public class GBDT {
private ArrayList> datas=new ArrayList>();
private ArrayList labelSets=new ArrayList<>();
private ArrayList> F=new ArrayList>();
private ArrayList> residual=new ArrayList>();
private ArrayList> trainData=new ArrayList>();
private ArrayList labelTrainData=new ArrayList();
private int K;
private Boolean isDigit[];
private int dim;
private int n;
private double learningRate;
private ArrayList> trees=new ArrayList>(); //存放所有的树
private int max_iter;
private double sampleRate;
private int maxDepth;
private int splitPoints;
public void computeResidual(ArrayList subId)
{
for(int i=0;i
{
int idx=subId.get(i);
int y=0;
if(this.labelTrainData.get(idx)==-1) y=0;
else y=1;
double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));
double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;
this.residual.get(idx).set(0, y-p1);
this.residual.get(idx).set(1, y-p2);
}
}
public ArrayList myrandom(int maxNum,int num)
{
ArrayList ans=new ArrayList<>();
Set mySet=new HashSet<>();
while(mySet.size()
{
Random r=new Random();
int tmp=r.nextInt(maxNum);
mySet.add(tmp);
}
Iterator it=mySet.iterator();
while(it.hasNext())
{
ans.add(it.next());
}
return ans;
}
public GBDT()
{
this.max_iter=50;
this.sampleRate=0.8;
this.K=2;//2分类问题
this.maxDepth=6;
this.splitPoints=3;
this.learningRate=0.01;
getData();
}
public void train()
{
for(int i=0;i
{
ArrayList subSet=new ArrayList<>();
int numSubset=(int)(this.n*this.sampleRate);
subSet=myrandom(this.n,numSubset);
computeResidual(subSet);
ArrayList target=new ArrayList<>();
ArrayList tmpTree=new ArrayList<>();
int maxdepths[]={this.maxDepth};
for(int j=0;j
{
target.clear();
for(int k=0;k
{
target.add(residual.get(subSet.get(k)).get(j));
}
ArrayList> leafNodes=new ArrayList>();
ArrayList leafValues=new ArrayList<>();
Tree treeSub=new Tree();
Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
tmpTree.add(iterTree);
updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
}
trees.add(tmpTree);
}
}
public void updateFvalue(Boolean isDigit[], ArrayList subIdx,ArrayList> leafNodes,ArrayList leafValues,int label,Tree root)
{
ArrayList remainIdx=new ArrayList<>();
int arr[]=new int[this.n];
for(int i=0;i
arr[i]=i;
for(int i=0;i
{
arr[subIdx.get(i)]=-1;
}
//求出不是用来训练树的余下集合
for(int i=0;i
{
if(arr[i]!=-1)
remainIdx.add(i);
}
for(int i=0;i
{
for(int j=0;j
{
this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));
}
}
for(int i=0;i
{
double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);
this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);
}
}
public boolean checkDigit(String str) {
for(int i=0;i
{
if(!(str.charAt(i)>='0'&&str.charAt(i)<='9'))
{
return false;
}
}
return true;
}
public void getData() {
Data d =new Data();
this.datas=d.getTrainData();
this.dim=this.datas.get(0).size()-1;
this.isDigit=new Boolean[this.dim];
//遍历所有样本,去掉中间含有不是正常的数据
for(int i=0;i
labelSets.add(this.datas.get(0).get(i));
//保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串
for(int i=0;i
{
if(checkDigit(this.datas.get(0).get(i)))
this.isDigit[i]=true;
else this.isDigit[i]=false;
}
//如果字符串==?说明是异常数据,这里做数据的清理
for(int i=1;i
{
ArrayList tmp=new ArrayList<>();
boolean flag=true;
for(int j=0;j
{
if(datas.get(i).get(j).trim().equals("?"))
{
flag=false;
break;
}
}
if(!flag) continue;
if(datas.get(i).get(this.dim).trim().equals("?")) continue;
trainData.add(tmp);
if(datas.get(i).get(this.dim).trim().equals("<=50K"))
labelTrainData.add(-1);
else
labelTrainData.add(1);
}
this.n=this.labelTrainData.size();
for(int i=0;i
labelSets.add(this.datas.get(0).get(i));
//初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了
for(int i=0;i
{
ArrayList arrTmp=new ArrayList();
for(int j=0;j<2;j++)
{
arrTmp.add(0.0);
}
this.F.add(arrTmp);
this.residual.add(arrTmp);
}
}
public static void main(String[] args) {
GBDT dGbdt=new GBDT();
dGbdt.getData();
System.err.println(dGbdt.n);
}
}