svm scale java — Java实现简单版SVM (A simple SVM implementation in Java) | 学步园

package com.linger.svm;

import java.io.File;

import java.io.FileNotFoundException;

import java.io.IOException;

import java.io.RandomAccessFile;

import java.util.StringTokenizer;

/**
 * A minimal linear SVM trained with full-batch (sub)gradient descent on the primal objective
 *
 * <pre>
 *   J(w) = sum_m max(0, 1 - y_m * (w . x_m)) + (lambda / 2) * ||w||^2
 * </pre>
 *
 * <p>Labels are expected to be {@code +1} or {@code -1}. There is no separate bias parameter:
 * {@link #loadData} writes 1 into feature column 0 of every example, so {@code w[0]} plays the
 * role of the bias (and is regularized like every other weight).
 *
 * <p>Instances are not thread-safe: training state lives in mutable instance fields.
 */
public class SimpleSvm
{
    /** Default gradient-descent step size (a smaller alternative, 0.00001, was also noted). */
    private static final double DEFAULT_LEARNING_RATE = 0.001;

    /** Default objective value below which training stops early. */
    private static final double DEFAULT_THRESHOLD = 0.001;

    private int exampleNum;

    private int exampleDim;

    /** Weight vector; {@code null} until {@link #Train} has been called with data. */
    private double[] w;

    /** L2 regularization strength. */
    private final double lambda;

    /** Gradient-descent step size. */
    private final double lr;

    /** Training stops as soon as the objective falls below this value. */
    private final double threshold;

    /** Objective value computed by the most recent {@link #costAndGrad} call. */
    private double cost;

    /** Gradient of the objective with respect to {@link #w}. */
    private double[] grad;

    /** Raw scores {@code w . x_m} of the training examples, shared by cost and gradient. */
    private double[] yp;

    /**
     * Creates an SVM with the default learning rate (0.001) and stopping threshold (0.001).
     *
     * @param paramLambda L2 regularization strength
     */
    public SimpleSvm(double paramLambda)
    {
        this(paramLambda, DEFAULT_LEARNING_RATE, DEFAULT_THRESHOLD);
    }

    /**
     * Creates an SVM with explicit optimization settings.
     *
     * @param paramLambda   L2 regularization strength
     * @param learningRate  gradient-descent step size; must be positive
     * @param stopThreshold training stops once the objective drops below this value
     * @throws IllegalArgumentException if {@code learningRate} is not positive
     */
    public SimpleSvm(double paramLambda, double learningRate, double stopThreshold)
    {
        if (!(learningRate > 0))
        {
            throw new IllegalArgumentException("learningRate must be positive: " + learningRate);
        }
        lambda = paramLambda;
        lr = learningRate;
        threshold = stopThreshold;
    }

    /**
     * Computes the objective into {@link #cost} and its (sub)gradient into {@link #grad} for the
     * current weights, caching each example's raw score in {@link #yp}.
     */
    private void costAndGrad(double[][] X, double[] y)
    {
        cost = 0;
        for (int m = 0; m < exampleNum; m++)
        {
            yp[m] = 0;
            for (int d = 0; d < exampleDim; d++)
            {
                yp[m] += X[m][d] * w[d];
            }
            // Hinge loss: only examples inside the margin (y * score < 1) contribute.
            if (y[m] * yp[m] - 1 < 0)
            {
                cost += 1 - y[m] * yp[m];
            }
        }
        for (int d = 0; d < exampleDim; d++)
        {
            cost += 0.5 * lambda * w[d] * w[d];
        }
        for (int d = 0; d < exampleDim; d++)
        {
            // d/dw of (lambda/2) * w^2 is lambda * w. Taking abs() here would push negative
            // weights further from zero instead of shrinking them.
            grad[d] = lambda * w[d];
            for (int m = 0; m < exampleNum; m++)
            {
                if (y[m] * yp[m] - 1 < 0)
                {
                    grad[d] -= y[m] * X[m][d];
                }
            }
        }
    }

    /** Takes one gradient-descent step: {@code w -= lr * grad}. */
    private void update()
    {
        for (int d = 0; d < exampleDim; d++)
        {
            w[d] -= lr * grad[d];
        }
    }

    /**
     * Trains from scratch (weights reset to zero), printing the objective after every iteration.
     *
     * @param X        training examples, one row per example; the row width is taken from X[0]
     * @param y        labels ({@code +1} / {@code -1}); must have at least {@code X.length} entries
     * @param maxIters maximum number of gradient-descent iterations
     * @throws IllegalArgumentException if {@code y} is shorter than {@code X}
     */
    public void Train(double[][] X, double[] y, int maxIters)
    {
        exampleNum = X.length;
        if (exampleNum <= 0)
        {
            System.out.println("num of example <=0!");
            return;
        }
        if (y.length < exampleNum)
        {
            throw new IllegalArgumentException(
                    "y has " + y.length + " labels but X has " + exampleNum + " examples");
        }
        exampleDim = X[0].length;
        w = new double[exampleDim];
        grad = new double[exampleDim];
        yp = new double[exampleNum];
        for (int iter = 0; iter < maxIters; iter++)
        {
            costAndGrad(X, y);
            System.out.println("cost:" + cost);
            if (cost < threshold)
            {
                break;
            }
            update();
        }
    }

    /**
     * Returns a copy of the learned weights ({@code w[0]} is the bias when the data came from
     * {@link #loadData}), or an empty array if the model has not been trained.
     */
    public double[] getWeights()
    {
        return w == null ? new double[0] : w.clone();
    }

    /** Classifies {@code x} as {@code +1} or {@code -1} by the sign of {@code w . x}. */
    private int predict(double[] x)
    {
        double pre = 0;
        for (int j = 0; j < w.length; j++)
        {
            pre += x[j] * w[j];
        }
        // This decision threshold is generally chosen somewhere between -1 and 1.
        return pre >= 0 ? 1 : -1;
    }

    /**
     * Classifies every row of {@code testX} and prints the total count, error count, error rate
     * and accuracy to stdout (the rates print as NaN when {@code testX} is empty).
     *
     * @throws IllegalStateException if the model has not been trained
     */
    public void Test(double[][] testX, double[] testY)
    {
        if (w == null)
        {
            throw new IllegalStateException("Train must be called before Test");
        }
        int error = 0;
        for (int i = 0; i < testX.length; i++)
        {
            if (predict(testX[i]) != testY[i])
            {
                error++;
            }
        }
        System.out.println("total:" + testX.length);
        System.out.println("error:" + error);
        System.out.println("error rate:" + ((double) error / testX.length));
        System.out.println("acc rate:" + ((double) (testX.length - error) / testX.length));
    }

    /**
     * Reads LIBSVM-style sparse lines {@code label index:value index:value ...} into the
     * caller-supplied arrays, one example per non-blank line, then sets column 0 of every loaded
     * row to 1 as the bias feature (overwriting any value given for index 0).
     *
     * <p>{@code RandomAccessFile.readLine} decodes bytes as Latin-1, which is fine for ASCII
     * numeric data. Rows beyond the number of lines read are left untouched.
     *
     * @param X         destination feature matrix; each row must be wide enough for every index
     * @param y         destination label array
     * @param trainFile path of the data file
     * @throws IOException              if the file cannot be read
     * @throws IllegalArgumentException if the file has more examples than {@code X}/{@code y}
     *                                  can hold, or a feature index does not fit in a row
     */
    public static void loadData(double[][] X, double[] y, String trainFile) throws IOException
    {
        File file = new File(trainFile);
        try (RandomAccessFile raf = new RandomAccessFile(file, "r"))
        {
            int index = 0;
            String line;
            while ((line = raf.readLine()) != null)
            {
                StringTokenizer tokenizer = new StringTokenizer(line, " \t");
                if (!tokenizer.hasMoreTokens())
                {
                    continue; // blank line, e.g. a trailing newline at end of file
                }
                if (index >= X.length || index >= y.length)
                {
                    throw new IllegalArgumentException("File " + trainFile + " has more than "
                            + Math.min(X.length, y.length) + " examples");
                }
                y[index] = Double.parseDouble(tokenizer.nextToken());
                while (tokenizer.hasMoreTokens())
                {
                    StringTokenizer pair = new StringTokenizer(tokenizer.nextToken(), ":");
                    int k = Integer.parseInt(pair.nextToken());
                    double v = Double.parseDouble(pair.nextToken());
                    if (k < 0 || k >= X[index].length)
                    {
                        throw new IllegalArgumentException("Feature index " + k + " on example "
                                + index + " does not fit in a row of width " + X[index].length);
                    }
                    X[index][k] = v;
                }
                X[index][0] = 1;
                index++;
            }
        }
    }

    /**
     * Trains on one data file and evaluates on another. Usage:
     * {@code SimpleSvm [trainFile [testFile]]}; missing paths fall back to the author's original
     * locations.
     */
    public static void main(String[] args) throws IOException
    {
        String trainFile = args.length > 0 ? args[0] : "E:\\project\\workspace\\Algorithms\\bin\\train_bc";
        String testFile = args.length > 1 ? args[1] : "E:\\project\\workspace\\Algorithms\\bin\\test_bc";

        // Array sizes are hard-wired to the author's data: 400 training / 283 test rows, and 11
        // columns (bias in column 0 plus feature indices 1..10).
        // NOTE(review): if a file has fewer rows, the leftover all-zero rows keep label 0 and are
        // counted as errors by Test — confirm the sizes match the data.
        double[] y = new double[400];
        double[][] X = new double[400][11];
        loadData(X, y, trainFile);

        SimpleSvm svm = new SimpleSvm(0.0001);
        svm.Train(X, y, 7000);

        double[] test_y = new double[283];
        double[][] test_X = new double[283][11];
        loadData(test_X, test_y, testFile);
        svm.Test(test_X, test_y);
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值