理论非常简单,不作赘述。
[结果]
score > 0.6、10 棵树的情况下:
score > 0.51、10 棵树的情况下:
[数据]
链接:https://pan.baidu.com/s/1KW-g-mg00UzhYvtXe1vM7w
提取码:q6t6
复制这段内容后打开百度网盘手机App,操作更方便哦
[代码]
package IsoForest;
import org.ejml.data.DenseMatrix64F;
import java.io.*;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Random;
/**
 * Common base type for the nodes of an isolation tree.
 * It declares no members of its own; {@code ITreeBranch} and {@code ITreeLeaf}
 * in this file extend it.
 */
class ITree{
}
/**
 * Internal (splitting) node of an isolation tree.
 * Stores the index of the attribute that is tested, the value it is compared
 * against, and the two child sub-trees.
 */
class ITreeBranch extends ITree {

    /** Child sub-tree on the left side of the split. */
    ITree left;
    /** Child sub-tree on the right side of the split. */
    ITree right;
    /** Value the split attribute is compared against. */
    double splitValue;
    /** Column index of the attribute tested at this node. */
    int splitAttr;

    /**
     * Creates a branch node.
     *
     * @param left       left child
     * @param right      right child
     * @param splitValue value the attribute is compared against
     * @param splitAttr  column index of the tested attribute
     */
    public ITreeBranch(ITree left, ITree right, double splitValue, int splitAttr) {
        this.splitAttr = splitAttr;
        this.splitValue = splitValue;
        this.right = right;
        this.left = left;
    }

    // ---- read accessors ----

    /** @return the left child */
    public ITree getLeft() {
        return this.left;
    }

    /** @return the right child */
    public ITree getRight() {
        return this.right;
    }

    /** @return the value the split attribute is compared against */
    public double getSplitValue() {
        return this.splitValue;
    }

    /** @return the column index of the tested attribute */
    public int getSplitAttr() {
        return this.splitAttr;
    }

    // ---- write accessors ----

    /** @param left new left child */
    public void setLeft(ITree left) {
        this.left = left;
    }

    /** @param right new right child */
    public void setRight(ITree right) {
        this.right = right;
    }

    /** @param splitValue new comparison value */
    public void setSplitValue(double splitValue) {
        this.splitValue = splitValue;
    }

    /** @param splitAttr new attribute column index */
    public void setSplitAttr(int splitAttr) {
        this.splitAttr = splitAttr;
    }
}
/**
 * Terminal node of an isolation tree.
 * It only records an integer size supplied by whoever builds the tree.
 */
class ITreeLeaf extends ITree {

    /** Size value stored in this leaf. */
    int size;

    /**
     * Creates a leaf.
     *
     * @param size size value to store
     */
    public ITreeLeaf(int size) {
        this.size = size;
    }

    /** @param size new size value */
    public void setSize(int size) {
        this.size = size;
    }

    /** @return the stored size value */
    public int getSize() {
        return this.size;
    }
}
/**
 * A trained isolation forest: a list of isolation trees plus the sample size
 * used to normalise path lengths into an anomaly score.
 */
class IForest {

    /** Euler-Mascheroni constant used in the harmonic-number approximation. */
    private static final double EULER_CONSTANT = 0.5772156649;

    /** Samples whose anomaly score exceeds this value are reported as anomalies (-1). */
    private static final double ANOMALY_THRESHOLD = 0.6;

    /** Trees that make up the forest. */
    List<ITree> iTrees;
    /** Sample size passed to {@link #cost(int)} when normalising the average path length. */
    int maxSamples;

    /**
     * @param iTrees     trained trees
     * @param maxSamples sample size used for score normalisation
     */
    public IForest(List<ITree> iTrees, int maxSamples) {
        this.iTrees = iTrees;
        this.maxSamples = maxSamples;
    }

    /**
     * Classifies a single sample.
     *
     * @param x sample stored in row 0 of the matrix
     * @return {@code -1} when the anomaly score is above the threshold, otherwise {@code 1}
     * @throws IllegalArgumentException if the forest holds no trees
     */
    public double predict(DenseMatrix64F x) {
        // The null test must come first; otherwise a null list fails with an NPE
        // before the intended exception can be raised.
        if (iTrees == null || iTrees.isEmpty()) {
            throw new IllegalArgumentException("请训练后再预测");
        }
        double sum = 0;
        for (ITree tree : iTrees) {
            sum += pathLengh(x, tree, 0);
        }
        double exponent = -(sum / iTrees.size()) / cost(maxSamples);
        double score = Math.pow(2, exponent);
        return score > ANOMALY_THRESHOLD ? -1 : 1;
    }

    /**
     * Walks the sample down a tree and returns the adjusted path length.
     *
     * @param x           sample stored in row 0 of the matrix
     * @param tree        node to start from
     * @param path_length number of edges already traversed before {@code tree}
     * @return edges traversed until a leaf is reached plus {@code cost(leaf size)}
     */
    public double pathLengh(DenseMatrix64F x, ITree tree, int path_length) {
        ITree node = tree;
        int depth = path_length;
        // Type test instead of comparing class-name strings.
        while (!(node instanceof ITreeLeaf)) {
            ITreeBranch branch = (ITreeBranch) node;
            double value = x.get(0, branch.getSplitAttr());
            // Values strictly below the split go left, everything else goes right.
            node = value < branch.getSplitValue() ? branch.getLeft() : branch.getRight();
            depth++;
        }
        return depth + cost(((ITreeLeaf) node).getSize());
    }

    /**
     * Approximates a harmonic number as {@code ln(i) + 0.5772156649}.
     *
     * @param i argument of the approximation
     * @return {@code Math.log(i)} plus the constant
     */
    public double getHi(int i) {
        return Math.log(i) + EULER_CONSTANT;
    }

    /**
     * Normalisation term {@code 2*H(n-1) - 2*(n-1)/n}.
     *
     * @param n number of samples
     * @return {@code 1.0} for {@code n <= 1}, otherwise the term above
     */
    public double cost(int n) {
        // Guard first so that log(0) / log(negative) is never evaluated.
        // NOTE(review): the published isolation-forest definition uses 0 for n <= 1
        // (and 1 for n == 2); the existing value 1.0 is kept unchanged here - confirm.
        if (n <= 1) {
            return 1.0;
        }
        // 2.0 forces floating-point division; with int arithmetic 2*(n-1)/n is always 1.
        return 2 * getHi(n - 1) - 2.0 * (n - 1) / n;
    }

    /**
     * Computes the fraction of rows in a comma-separated file whose last column
     * equals the prediction made from the preceding columns.
     *
     * @param filepath path of the labelled test file
     * @return number of matching rows divided by the number of rows
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if the file contains no data rows
     */
    public double getAccurate(String filepath) throws IOException {
        List<String> lists = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new FileReader(filepath));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // Blank lines carry no sample and would fail number parsing.
                if (line.trim().length() > 0) {
                    lists.add(line);
                }
            }
        } finally {
            reader.close();
        }
        if (lists.isEmpty()) {
            throw new IllegalArgumentException("no samples found in " + filepath);
        }

        // Every column but the last is a feature; the last column is the label.
        int cols = lists.get(0).split(",").length - 1;
        double count = 0.0;
        for (String row : lists) {
            String[] strings = row.split(",");
            DenseMatrix64F sample = new DenseMatrix64F(1, cols);
            for (int j = 0; j < cols; j++) {
                sample.set(0, j, Double.parseDouble(strings[j]));
            }
            // Read the label from the column derived above rather than a fixed index.
            double label = Double.parseDouble(strings[cols]);
            if (predict(sample) == label) {
                count += 1.0;
            }
        }
        return count / lists.size();
    }
}
/**
 * Trains an {@link IForest} from a comma-separated file and, via {@link #main},
 * reports its accuracy on a labelled test file.
 */
public class IsoForest {

    /** Number of trees built by {@link #train(String)}. */
    private static final int NUM_TREES = 10;

    /** Upper bound on the number of rows each tree is trained on. */
    private static final int MAX_SUB_SAMPLE_SIZE = 256;

    /** Single random source shared by sub-sampling and tree growth. */
    private final Random random = new Random();

    /**
     * Loads a comma-separated file into a matrix, dropping the last column of every row.
     *
     * @param filepath file to read
     * @return matrix with one row per non-blank line and (columns - 1) columns
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if the file contains no data rows
     */
    public DenseMatrix64F loadFile(String filepath) throws IOException {
        List<String> lines = new ArrayList<String>();
        BufferedReader reader = new BufferedReader(new FileReader(filepath));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // Blank lines carry no sample and would fail number parsing.
                if (line.trim().length() > 0) {
                    lines.add(line);
                }
            }
        } finally {
            reader.close();
        }
        if (lines.isEmpty()) {
            throw new IllegalArgumentException("no samples found in " + filepath);
        }

        int col = lines.get(0).split(",").length - 1;
        DenseMatrix64F data = new DenseMatrix64F(lines.size(), col);
        for (int i = 0; i < lines.size(); i++) {
            String[] strings = lines.get(i).split(",");
            for (int j = 0; j < col; j++) {
                data.set(i, j, Double.parseDouble(strings[j]));
            }
        }
        return data;
    }

    /**
     * Draws {@code subSampleCount} distinct rows at random from the data set.
     * Previously the first {@code subSampleCount} rows were returned on every call,
     * which gave every tree the same training data.
     *
     * @param dataSet        source matrix
     * @param subSampleCount number of rows to draw, between 0 and {@code dataSet.numRows}
     * @return new matrix holding the drawn rows
     * @throws IllegalArgumentException if {@code subSampleCount} is out of range
     */
    public DenseMatrix64F getSubSample(DenseMatrix64F dataSet, int subSampleCount) {
        int total = dataSet.numRows;
        if (subSampleCount < 0 || subSampleCount > total) {
            throw new IllegalArgumentException(
                    "subSampleCount " + subSampleCount + " is outside 0.." + total);
        }
        int features = dataSet.numCols;

        int[] order = new int[total];
        for (int i = 0; i < total; i++) {
            order[i] = i;
        }

        DenseMatrix64F subSample = new DenseMatrix64F(subSampleCount, features);
        // Partial shuffle: position i receives a row picked from those not yet used.
        for (int i = 0; i < subSampleCount; i++) {
            int pick = i + random.nextInt(total - i);
            int row = order[pick];
            order[pick] = order[i];
            order[i] = row;
            for (int j = 0; j < features; j++) {
                subSample.set(i, j, dataSet.get(row, j));
            }
        }
        return subSample;
    }

    /**
     * Builds a forest of {@value #NUM_TREES} trees from the given training file.
     *
     * @param filepath training file, read with {@link #loadFile(String)}
     * @return the trained forest
     * @throws IOException if the file cannot be read
     */
    public IForest train(String filepath) throws IOException {
        DenseMatrix64F dataSet = loadFile(filepath);
        int numFeatures = dataSet.numCols;
        int subSampleSize = Math.min(MAX_SUB_SAMPLE_SIZE, dataSet.numRows);
        // The height limit is derived from the rows a tree actually sees.
        int maxLength = (int) Math.ceil(bottomChanging(subSampleSize, 2));

        List<ITree> iTrees = new ArrayList<ITree>();
        for (int i = 0; i < NUM_TREES; i++) {
            DenseMatrix64F subSample = getSubSample(dataSet, subSampleSize);
            iTrees.add(growTree(subSample, maxLength, numFeatures, 0));
        }
        // Normalise scores with the real sub-sample size, not a fixed 256.
        return new IForest(iTrees, subSampleSize);
    }

    /**
     * Recursively grows an isolation tree.
     *
     * @param data          rows reaching this node
     * @param maxLength     depth at which growth stops
     * @param numFeatures   number of columns to choose the split attribute from
     * @param currentLength depth of this node
     * @return a leaf when the depth limit is reached or at most one row remains,
     *         otherwise a branch whose children are grown from the two partitions
     */
    public ITree growTree(DenseMatrix64F data, int maxLength, int numFeatures, int currentLength) {
        int rows = data.numRows;
        if (currentLength >= maxLength || rows <= 1) {
            return new ITreeLeaf(rows);
        }

        int feature = random.nextInt(numFeatures);
        // NOTE(review): the split value is the attribute value of a random row; the
        // published algorithm draws it uniformly between the attribute's min and max.
        double splitPoint = data.get(random.nextInt(rows), feature);

        List<Integer> leftList = new ArrayList<Integer>();
        List<Integer> rightList = new ArrayList<Integer>();
        for (int i = 0; i < rows; i++) {
            if (data.get(i, feature) >= splitPoint) {
                rightList.add(i);
            } else {
                leftList.add(i);
            }
        }

        // Copy the rows recorded in each list; copying rows 0..k-1 would hand both
        // children a prefix of the parent instead of their partition.
        DenseMatrix64F left = copyRows(data, leftList, numFeatures);
        DenseMatrix64F right = copyRows(data, rightList, numFeatures);

        return new ITreeBranch(
                growTree(left, maxLength, numFeatures, currentLength + 1),
                growTree(right, maxLength, numFeatures, currentLength + 1),
                splitPoint, feature);
    }

    /** Copies the listed rows (first {@code numFeatures} columns) of {@code data} into a new matrix. */
    private DenseMatrix64F copyRows(DenseMatrix64F data, List<Integer> rowIndices, int numFeatures) {
        DenseMatrix64F subset = new DenseMatrix64F(rowIndices.size(), numFeatures);
        for (int i = 0; i < rowIndices.size(); i++) {
            int sourceRow = rowIndices.get(i);
            for (int j = 0; j < numFeatures; j++) {
                subset.set(i, j, data.get(sourceRow, j));
            }
        }
        return subset;
    }

    /**
     * Logarithm of {@code x} in base {@code bottom}, via the change-of-base formula.
     *
     * @param x      value
     * @param bottom base
     * @return {@code log10(x) / log10(bottom)}
     */
    public double bottomChanging(int x, int bottom) {
        return Math.log10(x) / Math.log10(bottom);
    }

    /**
     * Trains and evaluates a forest 20 times, printing each accuracy and the total time.
     *
     * @param args optional: {@code args[0]} training file, {@code args[1]} test file;
     *             the original hard-coded paths are used when they are absent
     * @throws IOException if either file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String filepath = args.length > 0
                ? args[0]
                : "C:\\Users\\dell\\Desktop\\waterData\\trainForIsoForest.txt";
        String testPath = args.length > 1
                ? args[1]
                : "C:\\Users\\dell\\Desktop\\waterData\\testForIsoForest.txt";

        long start = System.currentTimeMillis();
        for (int run = 0; run < 20; run++) {
            IsoForest isoForest = new IsoForest();
            IForest forest = isoForest.train(filepath);
            double accurate = forest.getAccurate(testPath);
            System.out.println("accurate is " + accurate);
        }
        long elapse = System.currentTimeMillis() - start;
        System.out.println("花费时间" + elapse / 1000.0 + "s");
    }
}
[结论]
基于相同数据使用自己编写的SVM进行测试,SVM代码见https://blog.csdn.net/qq_34661106/article/details/103371568,结果如下图:
相比于孤立森林,SVM 的准确率波动较大(原因是没有使用 KKT 条件作为停机条件,并且选择第一个乘子时是随机挑选的),耗时也更长,但准确率较高,最高能达到 96.9%。对于孤立森林,直接影响其准确率的是异常得分阈值的取值和树的数量;下图为在 0.4-0.7 的范围内选择异常得分阈值、使用 500 棵树的结果,可以看到准确率有了明显的提升。而对于 SVM,直接影响其准确率的是核函数的选用。