java实现孤立森林

孤立森林（Isolation Forest，由 Liu、Ting 和 Zhou 于 2008 年提出）的原理较为简单，此处不再赘述。

[结果]

异常得分阈值取 score > 0.6、使用 10 棵树的情况下：

异常得分阈值取 score > 0.51、使用 10 棵树的情况下：

[数据]

链接:https://pan.baidu.com/s/1KW-g-mg00UzhYvtXe1vM7w 
提取码:q6t6 
复制这段内容后打开百度网盘手机App,操作更方便哦

[代码]

package IsoForest;

import org.ejml.data.DenseMatrix64F;

import java.io.*;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;


/**
 * Common base type for the nodes of an isolation tree.
 *
 * <p>It declares no members of its own; the concrete node kinds used in this file are
 * {@code ITreeBranch} (internal split node) and {@code ITreeLeaf} (external node).
 */
class ITree {
}

/**
 * Internal (split) node of an isolation tree.
 *
 * <p>Stores the two sub-trees together with the attribute index and threshold that
 * separate them; the class itself only holds these values. The forest descends into
 * {@code left} when a sample's value on {@code splitAttr} is below {@code splitValue},
 * and into {@code right} otherwise.
 */
class ITreeBranch extends ITree {
    /** Sub-tree for samples whose split-attribute value is below {@code splitValue}. */
    ITree left;
    /** Sub-tree for samples whose split-attribute value is at or above {@code splitValue}. */
    ITree right;
    /** Threshold compared against a sample's value on {@code splitAttr}. */
    double splitValue;
    /** Column index of the attribute this node splits on. */
    int splitAttr;

    /**
     * Creates a split node.
     *
     * @param leftChild  sub-tree for values below {@code threshold}
     * @param rightChild sub-tree for values at or above {@code threshold}
     * @param threshold  the split value
     * @param attribute  column index of the split attribute
     */
    public ITreeBranch(ITree leftChild, ITree rightChild, double threshold, int attribute) {
        left = leftChild;
        right = rightChild;
        splitValue = threshold;
        splitAttr = attribute;
    }

    /** @return the sub-tree for values below the split value */
    public ITree getLeft() { return left; }

    /** @param subtree new sub-tree for values below the split value */
    public void setLeft(ITree subtree) { left = subtree; }

    /** @return the sub-tree for values at or above the split value */
    public ITree getRight() { return right; }

    /** @param subtree new sub-tree for values at or above the split value */
    public void setRight(ITree subtree) { right = subtree; }

    /** @return the split threshold */
    public double getSplitValue() { return splitValue; }

    /** @param threshold new split threshold */
    public void setSplitValue(double threshold) { splitValue = threshold; }

    /** @return the column index of the split attribute */
    public int getSplitAttr() { return splitAttr; }

    /** @param attribute new column index of the split attribute */
    public void setSplitAttr(int attribute) { splitAttr = attribute; }
}

/**
 * External (leaf) node of an isolation tree.
 *
 * <p>Records how many training rows reached this node when the tree was grown.
 */
class ITreeLeaf extends ITree {
    /** Number of training rows that ended up in this leaf. */
    int size;

    /**
     * Creates a leaf.
     *
     * @param sampleCount number of rows that reached this leaf
     */
    public ITreeLeaf(int sampleCount) {
        size = sampleCount;
    }

    /** @return the number of rows stored at this leaf */
    public int getSize() { return size; }

    /** @param sampleCount new row count for this leaf */
    public void setSize(int sampleCount) { size = sampleCount; }
}

/**
 * A trained isolation forest.
 *
 * <p>Holds the fitted isolation trees plus the sub-sample size {@code maxSamples} (ψ)
 * they were grown from. ψ normalises average path lengths into anomaly scores
 * {@code s = 2^(-E[h(x)] / c(ψ))}.
 */
class IForest{
    /** Default anomaly-score threshold; scores strictly above it are labelled -1. */
    static final double DEFAULT_THRESHOLD = 0.6;

    /** Euler–Mascheroni constant used in the harmonic-number approximation. */
    private static final double EULER_GAMMA = 0.5772156649;

    List<ITree> iTrees;
    int maxSamples;
    /** Score cut-off used by {@link #predict}. */
    double threshold;

    /**
     * Creates a forest that labels points using {@link #DEFAULT_THRESHOLD}.
     *
     * @param iTrees     the trained isolation trees
     * @param maxSamples sub-sample size ψ each tree was grown from
     */
    public IForest(List<ITree> iTrees, int maxSamples) {
        this(iTrees, maxSamples, DEFAULT_THRESHOLD);
    }

    /**
     * Creates a forest with a custom labelling threshold.
     *
     * @param iTrees     the trained isolation trees
     * @param maxSamples sub-sample size ψ each tree was grown from
     * @param threshold  scores strictly above this value are labelled -1 by {@link #predict}
     */
    public IForest(List<ITree> iTrees, int maxSamples, double threshold) {
        this.iTrees = iTrees;
        this.maxSamples = maxSamples;
        this.threshold = threshold;
    }

    /**
     * Computes the anomaly score of a single sample.
     *
     * @param x a 1 x d row vector of features
     * @return the score in (0, 1]; larger means more anomalous
     * @throws IllegalArgumentException if the forest has no trees
     */
    public double score(DenseMatrix64F x) {
        // Null must be checked before size/isEmpty, otherwise the guard itself throws NPE.
        if (iTrees == null || iTrees.isEmpty()) {
            throw new IllegalArgumentException("请训练后再预测");
        }

        double sum = 0;
        for (ITree tree : iTrees) {
            sum += pathLengh(x, tree, 0);
        }
        double meanPathLength = sum / iTrees.size();

        double normaliser = cost(maxSamples);
        if (normaliser <= 0) {
            // psi <= 1: no point can be isolated faster than another, so return the
            // neutral score 0.5 instead of dividing by zero.
            return 0.5;
        }
        return Math.pow(2, -meanPathLength / normaliser);
    }

    /**
     * Labels a sample.
     *
     * @param x a 1 x d row vector of features
     * @return -1 if {@link #score} exceeds {@link #threshold} (outlier), otherwise 1
     * @throws IllegalArgumentException if the forest has no trees
     */
    public double predict(DenseMatrix64F x){
        return score(x) > threshold ? -1 : 1;
    }

    /**
     * Returns the path length of {@code x} in {@code tree}.
     *
     * <p>At a leaf, c(size) is added to account for the sub-tree that was not grown
     * because of the height limit.
     *
     * @param x           a 1 x d row vector of features
     * @param tree        the (sub-)tree to descend
     * @param path_length number of edges already traversed
     * @return the adjusted path length
     */
    public double pathLengh(DenseMatrix64F x, ITree tree, int path_length){
        if (tree instanceof ITreeLeaf) {
            return path_length + cost(((ITreeLeaf) tree).getSize());
        }

        ITreeBranch branch = (ITreeBranch) tree;
        double value = x.get(0, branch.getSplitAttr());
        // Strictly below the split value goes left, everything else goes right.
        ITree next = value < branch.getSplitValue() ? branch.getLeft() : branch.getRight();
        return pathLengh(x, next, path_length + 1);
    }

    /**
     * Approximates the i-th harmonic number as {@code ln(i) + γ}.
     *
     * @param i index of the harmonic number (meaningful for i >= 1)
     * @return the approximation
     */
    public double getHi(int i){
        return Math.log(i) + EULER_GAMMA;
    }

    /**
     * Average path length c(n) of an unsuccessful BST search over n points.
     *
     * <p>This is the piecewise definition of Liu, Ting &amp; Zhou (2012):
     * <ul>
     *   <li>c(n) = 0 for n &lt;= 1</li>
     *   <li>c(2) = 1</li>
     *   <li>c(n) = 2H(n-1) - 2(n-1)/n otherwise</li>
     * </ul>
     *
     * @param n number of points
     * @return c(n)
     */
    public double cost(int n){
        if (n <= 1) {
            return 0.0;
        }
        if (n == 2) {
            return 1.0;
        }
        // Floating-point division: 2*(n-1)/n in int arithmetic truncates to 1.
        return 2.0 * getHi(n - 1) - 2.0 * (n - 1) / n;
    }

    /**
     * Computes the fraction of labelled rows whose {@link #predict} result matches the label.
     *
     * <p>Each non-blank line is comma-separated: the features come first and the expected
     * label is the last column. Labels are compared numerically against {@link #predict}'s
     * -1 / 1 output.
     *
     * @param filepath path to the labelled test file
     * @return accuracy in [0, 1]
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if the file contains no data rows
     */
    public double getAccurate(String filepath) throws IOException {
        List<String> lines = readDataLines(filepath);
        if (lines.isEmpty()) {
            throw new IllegalArgumentException("no data rows in " + filepath);
        }

        int cols = lines.get(0).split(",").length - 1;

        int correct = 0;
        for (String row : lines) {
            String[] fields = row.split(",");
            DenseMatrix64F sample = new DenseMatrix64F(1, cols);
            for (int j = 0; j < cols; j++) {
                sample.set(0, j, Double.parseDouble(fields[j]));
            }
            // The label sits right after the features, whatever the column count.
            double label = Double.parseDouble(fields[cols]);
            if (predict(sample) == label) {
                correct++;
            }
        }

        return (double) correct / lines.size();
    }

    /** Reads all non-blank lines of a text file, always closing the reader. */
    private static List<String> readDataLines(String filepath) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(filepath));
        try {
            List<String> lines = new ArrayList<String>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() > 0) {
                    lines.add(line);
                }
            }
            return lines;
        } finally {
            reader.close();
        }
    }
}

/**
 * Trains isolation forests from comma-separated data files.
 *
 * <p>{@link #main} runs a repeated accuracy experiment on a train/test file pair.
 */
public class IsoForest {
    /** Number of trees grown by {@link #train(String)}. */
    static final int DEFAULT_NUM_TREES = 10;
    /** Requested sub-sample size ψ per tree for {@link #train(String)}. */
    static final int DEFAULT_SUB_SAMPLE_SIZE = 256;

    /** Shared source of randomness for sub-sampling and split selection. */
    private final Random random;

    /** Creates a trainer with an unseeded random source. */
    public IsoForest() {
        this(new Random());
    }

    /**
     * Creates a trainer with a caller-supplied random source.
     *
     * @param random randomness for sub-sampling and splits, e.g. a seeded {@code Random}
     *               for reproducible forests
     */
    public IsoForest(Random random) {
        if (random == null) {
            throw new IllegalArgumentException("random must not be null");
        }
        this.random = random;
    }

    /**
     * Loads the feature matrix from a comma-separated file.
     *
     * <p>Blank lines are skipped. The last column of every line is dropped (presumably a
     * label column — verify against the data files). The remaining columns become one
     * matrix row per line.
     *
     * @param filepath path to the data file
     * @return an (rows x features) matrix
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if the file contains no data rows
     */
    public DenseMatrix64F loadFile(String filepath) throws IOException {
        List<String> lines = readDataLines(filepath);
        if (lines.isEmpty()) {
            throw new IllegalArgumentException("no data rows in " + filepath);
        }

        int col = lines.get(0).split(",").length - 1;
        DenseMatrix64F data = new DenseMatrix64F(lines.size(), col);

        for (int i = 0; i < lines.size(); i++) {
            String[] fields = lines.get(i).split(",");
            for (int j = 0; j < col; j++) {
                data.set(i, j, Double.parseDouble(fields[j]));
            }
        }
        return data;
    }

    /**
     * Draws {@code subSampleCount} distinct rows uniformly at random, without replacement.
     *
     * @param dataSet        source matrix
     * @param subSampleCount number of rows to draw; must be in [0, dataSet.numRows]
     * @return a new (subSampleCount x features) matrix
     * @throws IllegalArgumentException if {@code subSampleCount} is out of range
     */
    public DenseMatrix64F getSubSample(DenseMatrix64F dataSet, int subSampleCount) {
        int rows = dataSet.numRows;
        if (subSampleCount < 0 || subSampleCount > rows) {
            throw new IllegalArgumentException(
                    "subSampleCount " + subSampleCount + " not in [0, " + rows + "]");
        }
        int features = dataSet.numCols;

        int[] order = new int[rows];
        for (int i = 0; i < rows; i++) {
            order[i] = i;
        }

        DenseMatrix64F subSample = new DenseMatrix64F(subSampleCount, features);
        // Partial Fisher–Yates shuffle: position i receives a random not-yet-chosen row.
        for (int i = 0; i < subSampleCount; i++) {
            int pick = i + random.nextInt(rows - i);
            int tmp = order[i];
            order[i] = order[pick];
            order[pick] = tmp;
            copyRow(dataSet, order[i], subSample, i, features);
        }
        return subSample;
    }

    /**
     * Trains a forest of {@link #DEFAULT_NUM_TREES} trees on sub-samples of at most
     * {@link #DEFAULT_SUB_SAMPLE_SIZE} rows.
     *
     * @param filepath path to the training file
     * @return the trained forest
     * @throws IOException if the file cannot be read
     */
    public IForest train(String filepath) throws IOException {
        return train(filepath, DEFAULT_NUM_TREES, DEFAULT_SUB_SAMPLE_SIZE);
    }

    /**
     * Trains a forest with explicit hyper-parameters.
     *
     * @param filepath      path to the training file
     * @param numTrees      number of isolation trees, must be positive
     * @param subSampleSize requested sub-sample size ψ per tree, must be positive; capped
     *                      at the number of rows in the file
     * @return the trained forest, normalised with the ψ actually used
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if a hyper-parameter is not positive
     */
    public IForest train(String filepath, int numTrees, int subSampleSize) throws IOException {
        if (numTrees <= 0 || subSampleSize <= 0) {
            throw new IllegalArgumentException(
                    "numTrees and subSampleSize must be positive: " + numTrees + ", " + subSampleSize);
        }
        DenseMatrix64F dataSet = loadFile(filepath);
        int psi = Math.min(subSampleSize, dataSet.numRows);

        // Height limit ceil(log2 psi): roughly the average height of a tree over psi points.
        int maxLength = (int) Math.ceil(bottomChanging(psi, 2));
        int numFeatures = dataSet.numCols;

        List<ITree> iTrees = new ArrayList<ITree>(numTrees);
        for (int i = 0; i < numTrees; i++) {
            DenseMatrix64F subSample = getSubSample(dataSet, psi);
            iTrees.add(growTree(subSample, maxLength, numFeatures, 0));
        }

        // Scores must be normalised by c(psi) for the psi the trees were actually built with.
        return new IForest(iTrees, psi);
    }

    /**
     * Recursively grows an isolation tree.
     *
     * <p>At each node a random attribute is chosen and a split value is drawn uniformly from
     * [min, max) of that attribute over the node's rows. Rows strictly below the split go
     * left and the rest go right, matching the rule in {@code IForest.pathLengh}.
     *
     * @param data          rows reaching this node
     * @param maxLength     height limit
     * @param numFeatures   number of feature columns to choose from
     * @param currentLength depth of this node
     * @return the grown (sub-)tree
     */
    public ITree growTree(DenseMatrix64F data, int maxLength, int numFeatures, int currentLength) {
        int rows = data.numRows;
        if (currentLength >= maxLength || rows <= 1) {
            return new ITreeLeaf(rows);
        }

        int feature = random.nextInt(numFeatures);
        double min = data.get(0, feature);
        double max = min;
        for (int i = 1; i < rows; i++) {
            double v = data.get(i, feature);
            if (v < min) {
                min = v;
            }
            if (v > max) {
                max = v;
            }
        }
        double splitPoint = min + random.nextDouble() * (max - min);

        List<Integer> leftRows = new ArrayList<Integer>();
        List<Integer> rightRows = new ArrayList<Integer>();
        for (int i = 0; i < rows; i++) {
            if (data.get(i, feature) < splitPoint) {
                leftRows.add(i);
            } else {
                rightRows.add(i);
            }
        }

        // Copy the rows actually assigned to each side (by their original row index).
        DenseMatrix64F left = selectRows(data, leftRows, numFeatures);
        DenseMatrix64F right = selectRows(data, rightRows, numFeatures);

        return new ITreeBranch(
                growTree(left, maxLength, numFeatures, currentLength + 1),
                growTree(right, maxLength, numFeatures, currentLength + 1),
                splitPoint, feature);
    }

    /**
     * Computes the logarithm of {@code x} in base {@code bottom}.
     *
     * @param x      the argument
     * @param bottom the base
     * @return log_bottom(x)
     */
    public double bottomChanging(int x, int bottom) {
        return Math.log10(x) / Math.log10(bottom);
    }

    /** Builds a matrix from the given rows of {@code data}, in list order. */
    private static DenseMatrix64F selectRows(DenseMatrix64F data, List<Integer> rowIndices, int numFeatures) {
        DenseMatrix64F out = new DenseMatrix64F(rowIndices.size(), numFeatures);
        for (int i = 0; i < rowIndices.size(); i++) {
            copyRow(data, rowIndices.get(i), out, i, numFeatures);
        }
        return out;
    }

    /** Copies the first {@code numFeatures} columns of one row into another matrix. */
    private static void copyRow(DenseMatrix64F src, int srcRow, DenseMatrix64F dst, int dstRow, int numFeatures) {
        for (int j = 0; j < numFeatures; j++) {
            dst.set(dstRow, j, src.get(srcRow, j));
        }
    }

    /** Reads all non-blank lines of a text file, always closing the reader. */
    private static List<String> readDataLines(String filepath) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(filepath));
        try {
            List<String> lines = new ArrayList<String>();
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() > 0) {
                    lines.add(line);
                }
            }
            return lines;
        } finally {
            reader.close();
        }
    }

    /**
     * Trains and evaluates a forest 20 times and prints each accuracy and the total time.
     *
     * @param args optional: {@code args[0]} training file, {@code args[1]} test file; the
     *             original hard-coded paths are used when omitted
     * @throws IOException if a data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String filepath = args.length > 0
                ? args[0] : "C:\\Users\\dell\\Desktop\\waterData\\trainForIsoForest.txt";
        String testPath = args.length > 1
                ? args[1] : "C:\\Users\\dell\\Desktop\\waterData\\testForIsoForest.txt";

        long start = System.currentTimeMillis();
        for (int count = 0; count < 20; count++) {
            IsoForest isoForest = new IsoForest();
            IForest forest = isoForest.train(filepath);

            double accurate = forest.getAccurate(testPath);
            System.out.println("accurate is " + accurate);
        }

        long elapse = System.currentTimeMillis() - start;
        System.out.println("花费时间" + elapse / 1000.0 + "s");
    }
}

[结论]

基于相同的数据，使用自己编写的 SVM 进行测试（SVM 代码见 https://blog.csdn.net/qq_34661106/article/details/103371568），结果如下图：

相比于孤立森林，SVM 的准确率波动较大（原因是没有使用 KKT 条件作为停机条件，且第一个乘子是随机挑选的），耗时也更长，但准确率较高，最高能达到 96.9%。对于孤立森林，直接影响其准确率的是异常得分阈值的取值和树的数量：下图为异常得分阈值在 0.4–0.7 范围内取值、使用 500 棵树时的结果，可以看到准确率有了明显提升。而对于 SVM，直接影响其准确率的是核函数的选择。

 

 

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值