Java应用梯度下降求解线性SVM模型参考代码

下面的代码是参考网上的,直接执行,主要是为了后续进一步掌握SVM原理而发布。

两个基本原理还是要去掌握:SVM原理和梯度下降法。

1)SVM分类器:

支持向量机,因其英文名为support vector machine,故一般简称SVM,通俗来讲,它是一种二类分类模型,其基本模型定义为特征空间上的间隔最大的线性分类器,其学习策略便是间隔最大化,最终可转化为一个凸二次规划问题的求解。

我自己的理解是:特征空间上寻找一个最优平面来分类,这个最优的求解其实就是多维的约束规划问题。


2)梯度下降法:

梯度下降法,就是利用负梯度方向来决定每次迭代的新的搜索方向,使得每次迭代能使待优化的目标函数逐步减小。梯度下降法是2范数下的最速下降法。 最速下降法的一种简单形式是:x(k+1)=x(k)-a*g(k),其中a称为学习速率,可以是较小的常数。g(k)是x(k)的梯度。

我自己的理解是:切线求导数,梯度下降法是求解SVM的一种方法。


在实际文本分类中,怎么求解SVM,应该要根据实际来选择方法,如拉格朗日、对偶、核函数等,如果理解超平面比较复杂的话,可以用二维平面及其点到直线的距离来抽象理解多维度超平面空间的分类。


package sk.svm;

import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.StringTokenizer;

//梯度下降法
public class SimpleSvm {
	
	private int exampleNum;
	private int exampleDim;
	private double[] w;
	private double lambda;
	private double lr = 0.001;//0.00001
	private double threshold = 0.001;
	private double cost;
	private double[] grad;
	private double[] yp;
	public SimpleSvm(double paramLambda){
		lambda = paramLambda;		
	}
	
	private void CostAndGrad(double[][] X,double[] y){//梯度求解
		cost =0;
		for(int m=0;m<exampleNum;m++){
			yp[m]=0;
			for(int d=0;d<exampleDim;d++){
				yp[m]+=X[m][d]*w[d];
			}
			
			if(y[m]*yp[m]-1<0){
				cost += (1-y[m]*yp[m]);
			}
			
		}
		
		for(int d=0;d<exampleDim;d++){
			cost += 0.5*lambda*w[d]*w[d];
		}
		

		for(int d=0;d<exampleDim;d++){
			grad[d] = Math.abs(lambda*w[d]);	
			for(int m=0;m<exampleNum;m++){
				if(y[m]*yp[m]-1<0){
					grad[d]-= y[m]*X[m][d];
				}
			}
		}				
	}
	
	private void update(){
		for(int d=0;d<exampleDim;d++){
			w[d] -= lr*grad[d];
		}
	}
	
	public void Train(double[][] X,double[] y,int maxIters){
		exampleNum = X.length;
		if(exampleNum <=0) {
			System.out.println("num of example <=0!");
			return;
		}
		exampleDim = X[0].length;
		w = new double[exampleDim];
		grad = new double[exampleDim];
		yp = new double[exampleNum];
		
		for(int iter=0;iter<maxIters;iter++){
			
			CostAndGrad(X,y);
			System.out.println("cost:"+cost);
			if(cost< threshold){
				break;
			}
			update();	
		}
	}
	private int predict(double[] x){
		double pre=0;
		for(int j=0;j<x.length;j++){
			pre+=x[j]*w[j];
		}
		if(pre >=0)//这个阈值一般位于-1到1
			return 1;
		else return -1;
	}
	
	public void Test(double[][] testX,double[] testY){
		int error=0;
		for(int i=0;i<testX.length;i++){
			if(predict(testX[i]) != testY[i]){
				error++;
			}
		}
		System.out.println("total:"+testX.length);
		System.out.println("error:"+error);
		System.out.println("error rate:"+((double)error/testX.length));
		System.out.println("acc rate:"+((double)(testX.length-error)/testX.length));
	}
	
	
	
	public static void loadData(double[][]X,double[] y,String trainFile) throws IOException{
		
		File file = new File(trainFile);
		RandomAccessFile raf = new RandomAccessFile(file,"r");
		StringTokenizer tokenizer,tokenizer2; 

		int index=0;
		while(true){
			String line = raf.readLine();
			
			if(line == null) break;
			tokenizer = new StringTokenizer(line," ");
			y[index] = Double.parseDouble(tokenizer.nextToken());
			while(tokenizer.hasMoreTokens())
			{
				tokenizer2 = new StringTokenizer(tokenizer.nextToken(),":");
				int k = Integer.parseInt(tokenizer2.nextToken());
				double v = Double.parseDouble(tokenizer2.nextToken());
				X[index][k] = v;				
			}	
			X[index][0] =1;
			index++;		
		}
	}
	
	public static void main(String[] args) throws IOException {
		// TODO Auto-generated method stub
		double[] y = new double[400];
		double[][] X = new double[400][11];
		String trainFile = "D:\\tmp\\train_bc";
		loadData(X,y,trainFile);
		
		
		SimpleSvm svm = new SimpleSvm(0.0001);
		svm.Train(X,y,7000);
		
		double[] test_y = new double[283];
		double[][] test_X = new double[283][11];
		String testFile = "D:\\tmp\\test_bc";
		loadData(test_X,test_y,testFile);
		svm.Test(test_X, test_y);
		
	}
}

源代码和数据集下载:https://github.com/linger2012/simpleSvm

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值