最小二乘回归树 属性选择比率 gbdt基分类器 java

这个Java程序实现了最小二乘回归树(Cart)的数据加载、节点处理、特征选择和数据分割功能。通过设定最大深度、最小叶子节点数和属性选择比率,进行模型训练。使用了随机数种子确保可复现性,并提供了数据加载和分类预测的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

/**
 * 最小二乘回归树      缺失属性还没想好怎么处理好
 * @author ysh  1208706282
 *
 */
public class Cart {
    static double MISSINGDATA = -111111111;
    int mMaxDepth;              //设定的最大深度
    int mMinLeaf;               //节点最小样本数
    double mFeatureRate;        //属性选择比率
    List<Sample> mSamples;
    Random mRandom;
    Node mParent;                //回归树根节点
    static class Sample{
        Double label;
        List<Double> feature;
    }
    static class Node{
        List<Sample> samples;
        int depth;
        int featureId;
        double splitValue;
        double fitness;
        double predict;
        Node childs[];
        boolean leaf;
    }
    /**
     * 加载数据   回归树
     * @param path
     * @param regex
     * @throws Exception
     */
    public  void loadData(String path,String regex) throws Exception{
        mSamples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.label = Double.valueOf(splits[0]);
            
            sample.feature = new ArrayList<Double>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i+1]));
            }
            
            mSamples.add(sample);
        }
        reader.close();
    }
    public void setData(List<Sample> samples){
        this.mSamples = samples;
    }
    /**
     * 加载验证测试集
     * @param path
     * @param regex
     * @throws Exception
     */
    public  static List<Sample> loadTestData(String path,boolean hasLabel,String regex) throws Exception{
        List<Sample> samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            if(hasLabel){
                sample.label = Double.valueOf(splits[0]);
            }
            
            sample.feature = new ArrayList<Double>(splits.length-1);
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i+1]));
            }
            
            samples.add(sample);
        }
        reader.close();
        return samples;
    }
    /**
     * 求节点均值
     * @param samples
     * @return
     */
    double getAverage(Sample []samples){
        double avg = 0;
        for(Sample sample:samples){
            avg += sample.label;
        }
        return avg/samples.length;
    }
    /**
     * 判断样本标签是否一致
     * @param samples
     * @return
     */
    boolean isSame(Sample []samples){
        double label = samples[0].label;
        boolean issame = true;
        for(Sample sample:samples){
            if((label - sample.label)<1E-5){
                issame = false;
                break;
            }
        }
        return issame;
    }
    /**
     * 指定特征的不纯度 越小越好
     * @param samples
     * @param featIndex
     * @param node
     * @return
     */
    int getImpurity(Sample[] samples,final int featIndex,Node node){
        long start = System.currentTimeMillis();
        //System.out.println("getImpurity  "+featIndex+" "+start);
        int ret = 0;
        Arrays.sort(samples, new Comparator<Sample>(){

            @Override
            public int compare(Sample o1, Sample o2) {
                // TODO Auto-generated method stub
                int ret = 0;
                if(o1.feature.get(featIndex) < o2.feature.get(featIndex)){
                    ret = -1;
                }else{
                    ret = 1;
                }
                return ret;
            }});
        double ts = 0;
        double ls = 0;
        double rs = 0;
        double avgSplit = 0;
        for(Sample s:samples){
            ts += s.label;
        }
        ls += samples[0].label;
        double bestFitness = Double.MAX_VALUE;
        double bestSplit = 0;
        for(int i=1;i<samples.length;i++){
            ls += samples[i].label;
            if((samples[i].feature.get(featIndex)-samples[i-1].feature.get(featIndex))<1E-4){
                continue;
            }
            //System.out.println("getImpurity  "+featIndex+" "+(System.currentTimeMillis()-start)/1000);
            ls -= samples[i].label;
            double lavg = ls/i;
            double lerror = 0;
            for(int j=0;j<i;j++){
                lerror += (samples[j].label-lavg)*(samples[j].label-lavg);
            }
            double ravg = (ts-ls)/(samples.length-i);
            double rerror = 0;
            for(int j=i;j<samples.length;j++){
                rerror += (samples[j].label-ravg)*(samples[j].label-ravg);
            }
            if(bestFitness > (lerror+rerror)){
                bestFitness = lerror+rerror;
                bestSplit = (samples[i].feature.get(featIndex) + samples[i-1].feature.get(featIndex))/2;
            }
        }
        node.fitness = bestFitness;
        node.splitValue = bestSplit;
        return bestFitness!=Double.MAX_VALUE? 0:1;
    }
    /**
     * 找到最佳切分属性及其分割点
     * @param samples
     * @param node
     * @return
     */
    int findSplit(Sample []samples,Node node){
        int ret = 0;
        int featureIndex[] = new int[samples[0].feature.size()];
        for(int i=0;i<samples[0].feature.size();i++){
            featureIndex[i] = i;
        }
        int index = -1;
        for(int i=0;i<samples[0].feature.size();i++){
            index = mRandom.nextInt(samples[0].feature.size());
            featureIndex[i] = featureIndex[i]^featureIndex[index];
            featureIndex[index] = featureIndex[i]^featureIndex[index];
            featureIndex[i] = featureIndex[i]^featureIndex[index];
        }
        int bestFeatIdx = 0;
        double bestFitness = Double.MAX_VALUE;
        double bestSplitValue = 0;
        for(int feat=0;feat<featureIndex.length*mFeatureRate;feat++){
            int idx = featureIndex[feat];
            ret = getImpurity(samples,idx,node);
            if(ret != 0){
                continue;
            }
            if(bestFitness > node.fitness){
                bestFitness = node.fitness;
                bestFeatIdx = idx;
                bestSplitValue = node.splitValue;
            }
        }
        node.fitness = bestFitness;
        node.featureId = bestFeatIdx;
        node.splitValue = bestSplitValue;
        return bestFitness!=Double.MAX_VALUE ? 0:1;
    }
    /**
     * 分割数据
     * @param samples
     * @param node
     */
    public void splitData(Sample []samples,Node node){
        node.childs = new Node[3];
        for(int i=0;i<3;i++){
            node.childs[i] = new Node();
            node.childs[i].depth = node.depth+1;
            node.childs[i].samples = new ArrayList<Sample>();
        }
        int feat = node.featureId;
        for(Sample s:samples){
            if(s.feature.get(feat) == Cart.MISSINGDATA){
                node.childs[2].samples.add(s);
                continue;
            }
            if(s.feature.get(feat) < node.splitValue){
                node.childs[0].samples.add(s);
            }else{
                node.childs[1].samples.add(s);
            }
        }
    }
    /**
     * 递归训练创建
     * @param samples
     * @param node
     */
    public void fit(Sample []samples,Node node){
        node.predict = getAverage(samples);
        if((node.depth==mMaxDepth) || isSame(samples) || samples.length<mMinLeaf){
            node.leaf = true;
            return;
        }
        int ret = 0;
        ret = findSplit(samples,node);
        if(ret != 0){
            node.leaf = true;
            return;
        }
        splitData(samples,node);
        if(node.childs[0].samples.isEmpty() || node.childs[1].samples.isEmpty()){
            node.leaf = true;
            return;
        }
        Sample []s = null;
        for(int i=0;i<3;i++){
            s = new Sample[node.childs[i].samples.size()];
            for(int j=0;j<s.length;j++){
                s[j] = node.childs[i].samples.get(j);
            }
            if(s.length != 0){
                fit(s,node.childs[i]);
            }
        }
    }
    /**
     * 训练
     */
    public void train(){
        mParent = new Node();
        mParent.samples = mSamples;
        mParent.depth = 0;
        Sample []s = new Sample[mSamples.size()];
        for(int i=0;i<s.length;i++){
            s[i] = mSamples.get(i);
        }
        fit(s,mParent);
    }
    /**
     * 分类
     * @param sample
     * @return
     */
    public double classify(Sample sample){
        return classify(mParent,sample);
    }
    /**
     * 分类
     * @param node
     * @param sample
     * @return
     */
    public double classify(Node node,Sample sample){
        if(node.leaf == true){
            return node.predict;
        }
        int fea = node.featureId;
        if(sample.feature.get(fea) == Cart.MISSINGDATA){
            return classify(node.childs[2],sample);
        }
        if(sample.feature.get(fea) < node.splitValue){
            return classify(node.childs[0],sample);
        }else{
            return classify(node.childs[1],sample);
        }
    }
    /**
     * @param args
     * @throws Exception
     */
    public static void main(String[] args) throws Exception {
        // TODO Auto-generated method stub
        Random ran = new Random();
        ran.setSeed(10001);
        for(int i=0;i<10;i++){
            System.out.println(ran.nextInt(10));
        }
        Cart cart = new Cart();
        cart.mFeatureRate = 0.8;
        cart.mMaxDepth = 6;
        cart.mMinLeaf = 1;
        cart.mRandom = new Random();
        cart.mRandom.setSeed(100);
        cart.loadData("F:/2016-contest/20161001/train_data_1.csv", ",");
        System.out.println(System.currentTimeMillis());
        cart.train();
        List<Sample> samples = cart.loadTestData("F:/2016-contest/20161001/valid_data_1.csv", true, ",");
        double sum = 0;
        for(Sample s:samples){
            double val = cart.classify(s);
            sum += (val-s.label)*(val-s.label);
            System.out.println(cart.classify(s)+"  "+s.label);
        }
        System.out.println(sum/samples.size());
        System.out.println(System.currentTimeMillis());
    }

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值