Implementing a Simple SVM in Java


My recent image-classification work involves latent SVM, so to understand SVM more deeply I implemented a simple version myself.

I call it a simple version because it does not use the Lagrangian, the dual problem, kernel functions, and so on; instead it solves the problem with plain gradient descent. For the underlying math I followed http://blog.csdn.net/lifeitengup/article/details/10951655, which implements the same SVM in MATLAB.
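Concretely, the objective minimized below is the L2-regularized hinge loss (the bias is absorbed into w through a constant feature), and its (sub)gradient is exactly what the CostAndGrad method computes:

J(w) = \sum_{m=1}^{N} \max\left(0,\ 1 - y_m\, w^\top x_m\right) + \frac{\lambda}{2}\,\lVert w \rVert^2

\nabla_w J(w) = \lambda\, w \;-\; \sum_{m:\ y_m w^\top x_m < 1} y_m\, x_m

Each iteration then applies the plain gradient-descent update w \leftarrow w - \mathrm{lr}\,\nabla_w J(w).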


Source code and dataset download: https://github.com/linger2012/simpleSvm

The dataset comes from libsvm; I picked the breast-cancer_scale dataset at http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/breast-cancer_scale
and split it into two parts, a training set and a test set, corresponding to train_bc and test_bc.
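For reference, each line of a libsvm-format file is a label followed by sparse index:value pairs (feature indices start at 1); an illustrative line, with made-up values, looks like:

-1 1:0.5 2:0.1 3:0.2 4:0.3 5:0.1 6:0.4 7:0.2 8:0.1 9:0.1 10:0.0

Note that the code below assumes +1/-1 labels; as distributed, the breast-cancer data uses 2/4 labels, so they have to be remapped to ±1 when making the split.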

The test results (total examples, error count, error rate, and accuracy) are printed by the Test method at the end of a run. The full source code is as follows:

package com.linger.svm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.StringTokenizer;

public class SimpleSvm 
{
	private int exampleNum;            // number of training examples
	private int exampleDim;            // feature dimension (column 0 holds the bias term)
	private double[] w;                // weight vector, bias included
	private double lambda;             // L2 regularization strength
	private double lr = 0.001;         // learning rate (alternative value: 0.00001)
	private double threshold = 0.001;  // stop once the cost falls below this value
	private double cost;               // current objective value
	private double[] grad;             // gradient of the objective w.r.t. w
	private double[] yp;               // raw scores w·x for each training example
	public SimpleSvm(double paramLambda)
	{

		lambda = paramLambda;	
		
	}
	
	// Compute the L2-regularized hinge-loss cost and its (sub)gradient
	// over the whole training set.
	private void CostAndGrad(double[][] X,double[] y)
	{
		cost =0;
		for(int m=0;m<exampleNum;m++)
		{
			yp[m]=0;
			for(int d=0;d<exampleDim;d++)
			{
				yp[m]+=X[m][d]*w[d];
			}
			
			if(y[m]*yp[m]-1<0)
			{
				cost += (1-y[m]*yp[m]);
			}
			
		}
		
		// add the regularization term 0.5 * lambda * ||w||^2
		for(int d=0;d<exampleDim;d++)
		{
			cost += 0.5*lambda*w[d]*w[d];
		}
		

		// (sub)gradient: lambda*w minus the sum of y[m]*X[m][d] over margin-violating examples
		for(int d=0;d<exampleDim;d++)
		{
			grad[d] = lambda*w[d];
			for(int m=0;m<exampleNum;m++)
			{
				if(y[m]*yp[m]-1<0)
				{
					grad[d]-= y[m]*X[m][d];
				}
			}
		}				
	}
	
	// One batch gradient-descent step: w <- w - lr * grad.
	private void update()
	{
		for(int d=0;d<exampleDim;d++)
		{
			w[d] -= lr*grad[d];
		}
	}
	
	// Fit w by batch gradient descent for at most maxIters iterations,
	// stopping early once the cost drops below threshold.
	public void Train(double[][] X,double[] y,int maxIters)
	{
		exampleNum = X.length;
		if(exampleNum <=0) 
		{
			System.out.println("num of example <=0!");
			return;
		}
		exampleDim = X[0].length;
		w = new double[exampleDim];
		grad = new double[exampleDim];
		yp = new double[exampleNum];
		
		for(int iter=0;iter<maxIters;iter++)
		{
			
			CostAndGrad(X,y);
			System.out.println("cost:"+cost);
			if(cost< threshold)
			{
				break;
			}
			update();
			
		}
	}
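	// Classify a single example by the sign of the score w·x.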
	private int predict(double[] x)
	{
		double pre=0;
		for(int j=0;j<x.length;j++)
		{
			pre+=x[j]*w[j];
		}
		if(pre >=0)// this threshold usually lies between -1 and 1
			return 1;
		else return -1;
	}
	
	// Report the error count, error rate and accuracy on a labeled test set.
	public void Test(double[][] testX,double[] testY)
	{
		int error=0;
		for(int i=0;i<testX.length;i++)
		{
			if(predict(testX[i]) != testY[i])
			{
				error++;
			}
		}
		System.out.println("total:"+testX.length);
		System.out.println("error:"+error);
		System.out.println("error rate:"+((double)error/testX.length));
		System.out.println("acc rate:"+((double)(testX.length-error)/testX.length));
	}
	
	
	
	// Parse a file in libsvm sparse format ("label index:value ...") into X and y;
	// feature indices are 1-based, and column 0 of X is set to 1 as the bias term.
	public static void loadData(double[][]X,double[] y,String trainFile) throws IOException
	{
		
		File file = new File(trainFile);
		RandomAccessFile raf = new RandomAccessFile(file,"r");
		StringTokenizer tokenizer,tokenizer2; 

		int index=0;
		while(true)
		{
			String line = raf.readLine();
			
			if(line == null) break;
			tokenizer = new StringTokenizer(line," ");
			y[index] = Double.parseDouble(tokenizer.nextToken());
			//System.out.println(y[index]);
			while(tokenizer.hasMoreTokens())
			{
				tokenizer2 = new StringTokenizer(tokenizer.nextToken(),":");
				int k = Integer.parseInt(tokenizer2.nextToken());
				double v = Double.parseDouble(tokenizer2.nextToken());
				X[index][k] = v;
				//System.out.println(k);
				//System.out.println(v);				
			}	
			X[index][0] =1;
			index++;		
		}
	}
	
	public static void main(String[] args) throws IOException 
	{
		// Array sizes (400 training and 283 test examples, 10 features plus the bias column)
		// and the file paths are hard-coded for my split of breast-cancer_scale.
		double[] y = new double[400];
		double[][] X = new double[400][11];
		String trainFile = "E:\\project\\workspace\\Algorithms\\bin\\train_bc";
		loadData(X,y,trainFile);
		
		
		SimpleSvm svm = new SimpleSvm(0.0001);
		svm.Train(X,y,7000);
		
		double[] test_y = new double[283];
		double[][] test_X = new double[283][11];
		String testFile = "E:\\project\\workspace\\Algorithms\\bin\\test_bc";
		loadData(test_X,test_y,testFile);
		svm.Test(test_X, test_y);
		
	}

}



Author: linger
Original post: http://blog.csdn.net/lingerlanlan/article/details/38688539


