Implementing an SMO-based SVM classifier in Java

The complete project code, including test data, can be downloaded here:

https://download.csdn.net/download/luohualiushui1/10951459

First, a quick look at what an SVM is:

A support vector machine (SVM) is a generalized linear classifier that performs binary classification in a supervised-learning setting. Its decision boundary is the maximum-margin hyperplane solved for from the training samples.
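Concretely, a linear SVM predicts with the sign of f(x) = w·x + b, and the "maximum margin" is the distance 2/‖w‖ between the two hyperplanes w·x + b = +1 and w·x + b = −1. A minimal sketch of both quantities follows; the class and method names are illustrative, the numbers are made up, and none of it is part of the downloadable project.

```java
public class MarginDemo {

    // Linear decision value f(x) = w . x + b
    static double decisionValue(double[] w, double b, double[] x) {
        double s = b;
        for (int k = 0; k < w.length; k++) {
            s += w[k] * x[k];
        }
        return s;
    }

    // Predicted label: +1 on the positive side of the hyperplane, -1 otherwise
    static double predict(double[] w, double b, double[] x) {
        return decisionValue(w, b, x) > 0 ? 1.0 : -1.0;
    }

    // Distance between the hyperplanes w.x + b = +1 and w.x + b = -1
    static double marginWidth(double[] w) {
        double norm2 = 0;
        for (double wk : w) {
            norm2 += wk * wk;
        }
        return 2.0 / Math.sqrt(norm2);
    }

    public static void main(String[] args) {
        double[] w = {3, 4};
        double b = -1;
        System.out.println(predict(w, b, new double[] {1, 1}));   // f = 3 + 4 - 1
        System.out.println(predict(w, b, new double[] {-1, 0}));  // f = -3 - 1
        System.out.println(marginWidth(w));                      // 2 / ||w|| with ||w|| = 5
    }
}
```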

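An SVM measures the empirical risk with the hinge loss and adds a regularization term to the optimization problem to control the structural risk, which makes it a sparse and robust classifier. Through the kernel method an SVM can also perform nonlinear classification, and it is one of the most common kernel learning methods.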
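In formula terms, the hinge loss on one sample is max(0, 1 − y·f(x)) and the regularization term is ½‖w‖². Together they give the usual soft-margin objective ½‖w‖² + C·Σ max(0, 1 − y_i(w·x_i + b)), where C plays the same role as the C parameter passed to smoP later. The sketch below simply evaluates that objective for a given w and b; the names are illustrative and the data is made up.

```java
public class HingeLossDemo {

    // Hinge loss for one sample, given its functional margin y * f(x)
    static double hinge(double margin) {
        return Math.max(0.0, 1.0 - margin);
    }

    // Soft-margin objective: 0.5 * ||w||^2 + C * sum_i hinge(y_i * (w . x_i + b))
    static double objective(double[][] X, double[] y, double[] w, double b, double C) {
        double reg = 0;
        for (double wk : w) {
            reg += wk * wk;
        }
        double loss = 0;
        for (int i = 0; i < X.length; i++) {
            double f = b;
            for (int k = 0; k < w.length; k++) {
                f += w[k] * X[i][k];
            }
            loss += hinge(y[i] * f);
        }
        return 0.5 * reg + C * loss;
    }

    public static void main(String[] args) {
        double[][] X = {{2, 0}, {0.5, 0}, {-1, 0}};
        double[] y = {1, 1, -1};
        double[] w = {1, 0};
        // functional margins are 2, 0.5 and 1: only the second sample is penalized
        System.out.println(objective(X, y, w, 0.0, 1.0));
    }
}
```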

Next, the SMO (Sequential Minimal Optimization) algorithm:

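The core idea of SMO is to pick two alphas at a time and optimize them while holding all the other alphas fixed. It has to be two because the constraints require the sum of each alpha multiplied by its label to be zero (Σ α_i·y_i = 0); changing a single alpha on its own would break that constraint. The process repeats until a termination condition is met, at which point the program exits with the optimized result.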
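The pairing argument can be written out with the standard soft-margin dual problem that SMO solves. The symbols follow the code: m samples, labels y, penalty C, kernel matrix K, threshold b.

```latex
% Soft-margin dual problem
\max_{\alpha}\; \sum_{i=1}^{m}\alpha_i-\frac{1}{2}\sum_{i=1}^{m}\sum_{j=1}^{m}\alpha_i\alpha_j\,y_i y_j\,K(x_i,x_j)
\quad\text{s.t.}\quad 0\le\alpha_i\le C,\qquad \sum_{i=1}^{m}\alpha_i y_i=0

% With every alpha except alpha_i and alpha_j fixed, the equality constraint pins the pair to a line
\alpha_i y_i+\alpha_j y_j=-\sum_{k\ne i,j}\alpha_k y_k=\zeta

% Pair update used by innerL, with f(x_k)=\sum_i \alpha_i y_i K(x_i,x_k)+b and E_k=f(x_k)-y_k
\eta = 2K_{ij}-K_{ii}-K_{jj},\qquad
\alpha_j^{\text{new}}=\operatorname{clip}\!\Big(\alpha_j-\frac{y_j(E_i-E_j)}{\eta},\,L,\,H\Big),\qquad
\alpha_i^{\text{new}}=\alpha_i+y_i y_j\big(\alpha_j-\alpha_j^{\text{new}}\big)
```

Substituting the last formula shows that α_i·y_i + α_j·y_j is unchanged by the update, which is exactly why the two alphas must move together.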

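SMO breaks one large optimization problem into many small ones. Each small problem is easy to solve (a two-variable problem has an analytic solution), and solving them one after another gives the same result as solving the whole problem at once, in much less time. The goal of SMO is to find a set of alphas and the threshold b; once the alphas are known, the weight vector w, and with it the separating hyperplane, is easy to compute. In each iteration SMO picks two alphas to optimize. Once a suitable pair is found, one of them is increased and the other is decreased. "Suitable" means the pair must meet certain conditions: first, the two alphas must lie outside the margin boundary, and second, they must not already have been clipped to the interval bounds, i.e. they are not on the boundary. In the code below, this selection shows up as the KKT-violation test at the top of innerL.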
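"Increase one, decrease the other" can be made concrete with a standalone sketch of a single pair update. It follows the formulas in the Python innerL below (bounds L and H, eta, clipping) but works on plain doubles; the class and method names are hypothetical, and the example numbers assume all alphas start at 0 with b = 0, so E = f(x) − y = −y.

```java
public class SmoPairStep {

    // Clip a into [L, H] (same argument order as clipAlpha(aj, H, L))
    static double clip(double a, double H, double L) {
        return Math.max(L, Math.min(H, a));
    }

    // One analytic update of the pair (alpha_i, alpha_j).
    // Returns {newAlphaI, newAlphaJ}, or null if the pair cannot move.
    static double[] step(double ai, double aj, double yi, double yj,
                         double Ei, double Ej,
                         double Kii, double Kjj, double Kij, double C) {
        double L, H;
        if (yi != yj) {
            L = Math.max(0, aj - ai);
            H = Math.min(C, C + aj - ai);
        } else {
            L = Math.max(0, aj + ai - C);
            H = Math.min(C, aj + ai);
        }
        if (L == H) {
            return null;                 // no room to move alpha_j
        }
        double eta = 2.0 * Kij - Kii - Kjj;
        if (eta >= 0) {
            return null;                 // this simple version skips non-negative eta
        }
        double ajNew = clip(aj - yj * (Ei - Ej) / eta, H, L);
        // alpha_i moves by the same amount in the direction that keeps
        // alpha_i * y_i + alpha_j * y_j unchanged
        double aiNew = ai + yi * yj * (aj - ajNew);
        return new double[] {aiNew, ajNew};
    }

    public static void main(String[] args) {
        // opposite labels, K = identity, C = 0.5: the unclipped alpha_j would be 1
        double[] r = step(0, 0, 1, -1, -1, 1, 1, 1, 0, 0.5);
        System.out.println(r[0] + " " + r[1]);
        System.out.println(r[0] * 1 + r[1] * -1);   // sum of alpha * y stays 0
    }
}
```

With equal labels and both alphas at 0, L and H are both 0, so step returns null; this is the "L==H" early exit in innerL.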

Let's first look at the Python implementation of SMO from the reference material:

import random
from numpy import *

# pick a random index j that differs from i
def selectJrand(i,m):
    j=i #we want to select any J not equal to i
    while (j==i):
        j = int(random.uniform(0,m))
    return j

# clip alpha to the range [L, H]
def clipAlpha(aj,H,L):
    if aj > H:
        aj = H
    if L > aj:
        aj = L
    return aj

# core kernel computation
def kernelTrans(X, A, kTup): #calc the kernel or transform data to a higher dimensional space
    m,n = shape(X)
    K = mat(zeros((m,1)))
    if kTup[0]=='lin': K = X * A.T   #linear kernel
    elif kTup[0]=='rbf':
        for j in range(m):
            deltaRow = X[j,:] - A
            K[j] = deltaRow*deltaRow.T
        K = exp(K/(-1*kTup[1]**2)) #divide in NumPy is element-wise not matrix like Matlab
    else: raise NameError('Houston We Have a Problem -- That Kernel is not recognized')
    return K

# data/cache structure
class optStruct:
    def __init__(self, dataMatIn, classLabels, C, toler, kTup):  # Initialize the structure with the parameters
        self.X = dataMatIn
        self.labelMat = classLabels
        self.C = C
        self.tol = toler
        self.m = shape(dataMatIn)[0]
        self.alphas = mat(zeros((self.m, 1)))
        self.b = 0
        self.eCache = mat(zeros((self.m, 2)))  # first column is valid flag
        self.K = mat(zeros((self.m, self.m)))
        for i in range(self.m):
            self.K[:, i] = kernelTrans(self.X, self.X[i, :], kTup)


def calcEk(oS, k):
    fXk = float(multiply(oS.alphas, oS.labelMat).T * oS.K[:,k] + oS.b)
    Ek = fXk - float(oS.labelMat[k])
    return Ek

# select the second alpha
def selectJ(i, oS, Ei):  # this is the second choice heuristic, and calcs Ej
    maxK = -1;
    maxDeltaE = 0;
    Ej = 0
    oS.eCache[i] = [1, Ei]  # set valid #choose the alpha that gives the maximum delta E
    validEcacheList = nonzero(oS.eCache[:, 0].A)[0]
    if (len(validEcacheList)) > 1:
        for k in validEcacheList:  # loop through valid Ecache values and find the one that maximizes delta E
            if k == i: continue  # don't calc for i, waste of time
            Ek = calcEk(oS, k)
            deltaE = abs(Ei - Ek)
            if (deltaE > maxDeltaE):
                maxK = k;
                maxDeltaE = deltaE;
                Ej = Ek
        return maxK, Ej
    else:  # in this case (first time around) we don't have any valid eCache values
        j = selectJrand(i, oS.m)
        Ej = calcEk(oS, j)
    return j, Ej


def updateEk(oS, k):  # after any alpha has changed update the new value in the cache
    Ek = calcEk(oS, k)
    oS.eCache[k] = [1, Ek]

# optimize a pair of alphas (inner loop)
def innerL(i, oS):
    Ei = calcEk(oS, i)

    if ((oS.labelMat[i] * Ei < -oS.tol) and (oS.alphas[i] < oS.C)) or ((oS.labelMat[i] * Ei > oS.tol) and (oS.alphas[i] > 0)):
        j, Ej = selectJ(i, oS, Ei)
        alphaIold = oS.alphas[i].copy();
        alphaJold = oS.alphas[j].copy();
        if (oS.labelMat[i] != oS.labelMat[j]):
            L = max(0, oS.alphas[j] - oS.alphas[i])
            H = min(oS.C, oS.C + oS.alphas[j] - oS.alphas[i])
        else:
            L = max(0, oS.alphas[j] + oS.alphas[i] - oS.C)
            H = min(oS.C, oS.alphas[j] + oS.alphas[i])
        if L == H:
            print("L==H");
            return 0
        eta = 2.0 * oS.K[i, j] - oS.K[i, i] - oS.K[j, j]  # changed for kernel
        if eta >= 0:
            print("eta>=0");
            return 0
        oS.alphas[j] -= oS.labelMat[j] * (Ei - Ej) / eta
        oS.alphas[j] = clipAlpha(oS.alphas[j], H, L)
        updateEk(oS, j)  # added this for the Ecache
        if (abs(oS.alphas[j] - alphaJold) < 0.00001):
            print("j not moving enough");
            return 0
        oS.alphas[i] += oS.labelMat[j] * oS.labelMat[i] * (alphaJold - oS.alphas[j])  # update i by the same amount as j
        updateEk(oS, i)  # added this for the Ecache                    #the update is in the opposite direction
        b1 = oS.b - Ei - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.K[i, i] - oS.labelMat[j] * (
                    oS.alphas[j] - alphaJold) * oS.K[i, j]
        b2 = oS.b - Ej - oS.labelMat[i] * (oS.alphas[i] - alphaIold) * oS.K[i, j] - oS.labelMat[j] * (
                    oS.alphas[j] - alphaJold) * oS.K[j, j]
        if (0 < oS.alphas[i]) and (oS.C > oS.alphas[i]):
            oS.b = b1
        elif (0 < oS.alphas[j]) and (oS.C > oS.alphas[j]):
            oS.b = b2
        else:
            oS.b = (b1 + b2) / 2.0
        return 1
    else:
        return 0

# main training routine
def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):  # full Platt SMO
    oS = optStruct(mat(dataMatIn), mat(classLabels).transpose(), C, toler, kTup)
    iter = 0
    entireSet = True;
    alphaPairsChanged = 0
    while (iter < maxIter) and ((alphaPairsChanged > 0) or (entireSet)):
        alphaPairsChanged = 0
        if entireSet:  # go over all
            for i in range(oS.m):
                alphaPairsChanged += innerL(i, oS)
                print("fullSet, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged))
            iter += 1
        else:  # go over non-bound (railed) alphas
            nonBoundIs = nonzero((oS.alphas.A > 0) * (oS.alphas.A < C))[0]
            for i in nonBoundIs:
                alphaPairsChanged += innerL(i, oS)
                print("non-bound, iter: %d i:%d, pairs changed %d" % (iter, i, alphaPairsChanged))
            iter += 1
        if entireSet:
            entireSet = False  # toggle entire set loop
        elif (alphaPairsChanged == 0):
            entireSet = True
        print("iteration number: %d" % iter)
    return oS.b, oS.alphas

def calcWs(alphas,dataArr,classLabels):
    X = mat(dataArr); labelMat = mat(classLabels).transpose()
    m,n = shape(X)
    w = zeros((n,1))
    for i in range(m):
        w += multiply(alphas[i]*labelMat[i],X[i,:].T)
    return w

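Now for the Java implementation. First, define the data/cache class; matrices are handled with the EJML library (DenseMatrix64F and CommonOps):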

package com.algorithm;

import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

public class OptStruct {

	private DenseMatrix64F X;
	
	private DenseMatrix64F labelMat;
	
	private double C;
	
	private double tol;
	
	private int m;
	
	private DenseMatrix64F alphas;
	
	private double b;
	
	private DenseMatrix64F eCache;
	
	private DenseMatrix64F K;
	
	public OptStruct(DenseMatrix64F dataMatIn,double [] classLabels,double C,double toler,String[] kTup) {
		
		this.X = dataMatIn;
		this.labelMat = new DenseMatrix64F(classLabels.length,1);
		
		for(int i=0;i<classLabels.length;i++)
			this.labelMat.set(i, 0, classLabels[i]);
		
		this.C = C;
		this.tol = toler;
		this.m = dataMatIn.numRows;
		this.alphas = new DenseMatrix64F(dataMatIn.numRows,1);
		this.alphas.zero();
		this.b=0;
		this.eCache = new DenseMatrix64F(dataMatIn.numRows,2);
		this.eCache.zero();
		
		this.K = kernelTrans(dataMatIn,kTup,dataMatIn.numRows);
		
		
	}
	
	private DenseMatrix64F getMatrixRow(DenseMatrix64F dataMatIn,int row) {
		
		DenseMatrix64F rs = new DenseMatrix64F(1,dataMatIn.numCols);
		
		for(int i=0;i<dataMatIn.numCols;i++) {
			rs.set(0, i, dataMatIn.get(row, i));
		}
		
		return rs;
	}
	
	
	// Build the m x m kernel matrix: column i holds k(x_j, x_i) for every sample j.
	// kTup[0] is the kernel type ("lin" or "rbf"); for "rbf", kTup[1] is sigma.
	private DenseMatrix64F kernelTrans(DenseMatrix64F dataMatIn,String[] kTup,int m) {
		
		if(!kTup[0].equals("lin") && !kTup[0].equals("rbf")) {
			// fail fast, like the Python version, instead of returning a zero matrix
			throw new IllegalArgumentException("Houston We Have a Problem -- That Kernel is not recognized");
		}
		
		int n = dataMatIn.numCols;   // number of features
		DenseMatrix64F rs = new DenseMatrix64F(m,m);
		rs.zero();
		
		for(int i=0;i<m;i++) {
			
			if(kTup[0].equals("lin")) {
				// linear kernel: column i = X * x_i^T
				DenseMatrix64F ki = new DenseMatrix64F(m,1);
				DenseMatrix64F mrt = new DenseMatrix64F(n,1);
				CommonOps.transpose(getMatrixRow(dataMatIn,i),mrt);
				CommonOps.mult(dataMatIn,mrt, ki);
				
				for(int j=0;j<m;j++) {
					rs.set(j, i, ki.get(j, 0));
				}
				
			}else {
				// RBF kernel: exp(-||x_j - x_i||^2 / sigma^2)
				double sigma = Double.parseDouble(kTup[1]);
				for(int j=0;j<m;j++) {
					
					// the difference of two rows is 1 x n (features), not 1 x m
					DenseMatrix64F deltaRow = new DenseMatrix64F(1,n);
					CommonOps.subtract(getMatrixRow(dataMatIn,j),getMatrixRow(dataMatIn,i),deltaRow);
					
					DenseMatrix64F mrt = new DenseMatrix64F(n,1);
					CommonOps.transpose(deltaRow,mrt);
					
					// squared Euclidean distance as a 1 x 1 matrix
					DenseMatrix64F mrMutRs = new DenseMatrix64F(1,1);
					CommonOps.mult(deltaRow,mrt, mrMutRs);
					
					rs.set(j, i, Math.exp(mrMutRs.get(0, 0)/(-1*sigma*sigma)));
				}
			}
		}
		
		return rs;
	}
	
	

	public DenseMatrix64F getX() {
		return X;
	}

	public void setX(DenseMatrix64F x) {
		X = x;
	}



	public DenseMatrix64F getLabelMat() {
		return labelMat;
	}

	public void setLabelMat(DenseMatrix64F labelMat) {
		this.labelMat = labelMat;
	}

	public double getC() {
		return C;
	}

	public void setC(double c) {
		C = c;
	}

	public double getTol() {
		return tol;
	}

	public void setTol(double tol) {
		this.tol = tol;
	}

	public int getM() {
		return m;
	}

	public void setM(int m) {
		this.m = m;
	}

	public DenseMatrix64F getAlphas() {
		return alphas;
	}

	public void setAlphas(DenseMatrix64F alphas) {
		this.alphas = alphas;
	}

	public double getB() {
		return b;
	}

	public void setB(double b) {
		this.b = b;
	}

	public DenseMatrix64F geteCache() {
		return eCache;
	}

	public void seteCache(DenseMatrix64F eCache) {
		this.eCache = eCache;
	}

	public DenseMatrix64F getK() {
		return K;
	}

	public void setK(DenseMatrix64F k) {
		K = k;
	}
	
	
	
	
	
}
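To sanity-check kernelTrans without EJML, here is a plain-array sketch that builds the same m×m matrix (entry [j][i] is k(x_j, x_i)). It is a hypothetical test helper, not part of the project, and assumes the same kernel names "lin" and "rbf" and the same RBF form exp(−‖x_j − x_i‖²/σ²) used above.

```java
import java.util.Arrays;

public class KernelMatrixDemo {

    // Build the m x m kernel matrix with K[j][i] = k(x_j, x_i), like OptStruct.K.
    // kernel is "lin" or "rbf"; sigma is only used by "rbf".
    static double[][] kernelMatrix(double[][] X, String kernel, double sigma) {
        boolean rbf;
        if (kernel.equals("lin")) {
            rbf = false;
        } else if (kernel.equals("rbf")) {
            rbf = true;
        } else {
            throw new IllegalArgumentException("Kernel is not recognized: " + kernel);
        }
        int m = X.length;
        double[][] K = new double[m][m];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                double v = 0;
                if (!rbf) {
                    // linear kernel: dot product x_j . x_i
                    for (int c = 0; c < X[j].length; c++) {
                        v += X[j][c] * X[i][c];
                    }
                } else {
                    // RBF kernel: exp(-||x_j - x_i||^2 / sigma^2)
                    double sq = 0;
                    for (int c = 0; c < X[j].length; c++) {
                        double d = X[j][c] - X[i][c];
                        sq += d * d;
                    }
                    v = Math.exp(sq / (-1 * sigma * sigma));
                }
                K[j][i] = v;
            }
        }
        return K;
    }

    public static void main(String[] args) {
        double[][] X = {{0, 0}, {1, 0}};
        System.out.println(Arrays.deepToString(kernelMatrix(X, "lin", 0)));
        System.out.println(Arrays.deepToString(kernelMatrix(X, "rbf", 1.0)));
    }
}
```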

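Next come the computation functions. calcEk computes the prediction error E_k = f(x_k) − y_k for sample k, where f(x_k) = Σ α_i·y_i·K(x_i, x_k) + b. Like the methods that follow, it is static and reads the shared OptStruct through a static field os; os and the helper getMatrixCol are not shown in the article.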

// E_k = f(x_k) - y_k, where f(x_k) = (alphas .* labels)^T * K[:, k] + b
private static double calcEk(int k) {
		
		// element-wise product alpha_i * y_i (m x 1)
		DenseMatrix64F multRs = new DenseMatrix64F(os.getM(),1);
		CommonOps.elementMult(os.getAlphas(), os.getLabelMat(), multRs); 
		
		DenseMatrix64F mrt = new DenseMatrix64F(1,os.getM());
		CommonOps.transpose(multRs,mrt);
		
		// (1 x m) times column k of the kernel matrix (m x 1)
		DenseMatrix64F mrMutRs = new DenseMatrix64F(1,1);
		CommonOps.mult(mrt,getMatrixCol(os.getK(),k), mrMutRs);

		return mrMutRs.get(0, 0)+os.getB()-os.getLabelMat().get(k, 0);
	}
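The same error computation on plain arrays makes the formula easy to check by hand. This is a hypothetical stand-alone helper (not part of the project) with made-up values: two samples whose kernel matrix is the identity.

```java
public class ErrorCacheDemo {

    // f(x_k) = sum_i alpha_i * y_i * K[i][k] + b ; E_k = f(x_k) - y_k
    static double calcEk(double[] alphas, double[] y, double[][] K, double b, int k) {
        double fxk = b;
        for (int i = 0; i < alphas.length; i++) {
            fxk += alphas[i] * y[i] * K[i][k];
        }
        return fxk - y[k];
    }

    public static void main(String[] args) {
        double[] alphas = {0.5, 0.5};
        double[] y = {1, -1};
        double[][] K = {{1, 0}, {0, 1}};   // e.g. linear kernel of (1,0) and (0,1)
        System.out.println(calcEk(alphas, y, K, 0.0, 0));   // f = 0.5, y = 1
        System.out.println(calcEk(alphas, y, K, 0.0, 1));   // f = -0.5, y = -1
    }
}
```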
	
	
	

Random selection function:

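// one shared generator instead of creating a new Random on every call
private static final Random RAND = new Random();

private static int selectJrand(int i,int m) {
		int j=i;
		
		// draw until j differs from i (requires m >= 2)
		while(j == i) {
			j = RAND.nextInt(m);
		}
		
		return j;
	}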

Function that selects the second alpha:

	// Pick the second alpha: among indices with a valid error cache, choose the k
	// that maximizes |Ei - Ek|; on the first pass, fall back to a random j.
	private static SelectRs selectJ(int i,double Ei) {
		
		int maxK = -1;
	    double maxDeltaE = 0;
	    double Ej = 0;
	    
	    // mark E_i as valid in the cache
	    os.geteCache().set(i,0, 1);
	    os.geteCache().set(i,1, Ei);
	    
	    Integer[] validEcacheList = nonzero(os.geteCache(),0);
	    
	    SelectRs sr = new SelectRs();

	    if(validEcacheList.length > 1) {
	    	
	    	for(int k=0;k<validEcacheList.length;k++) {
	    		if(validEcacheList[k] == i)
	    			continue;
	    		
	    		double Ek = calcEk(validEcacheList[k]);
	    		
	    		double deltaE = Math.abs(Ei - Ek);
	    		
	    	    if (deltaE >= maxDeltaE) {
	                maxK = validEcacheList[k];
	                maxDeltaE = deltaE;
	                Ej = Ek;   // keep E_j of the chosen index
	    	    }
	    	}
	    	
	    	sr.setJ(maxK);
	    	sr.setEj(Ej);
	    	
	    }else {
	    	// no valid cache entries yet: choose j at random
	    	sr.setJ(selectJrand(i,os.getM()));
	    	sr.setEj(calcEk(sr.getJ()));
	    }
	    
	    return sr;
		
	}
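selectJ returns its result through a SelectRs object, and innerL calls clipAlpha; neither is shown in the article. Minimal versions inferred from how they are called might look like this. The field names and bodies are assumptions; clipAlpha mirrors the Python clipAlpha.

```java
// Hypothetical reconstruction of the result holder used by selectJ
class SelectRs {
    private int j;       // index of the second alpha
    private double Ej;   // error E_j for that index

    public int getJ() { return j; }
    public void setJ(int j) { this.j = j; }
    public double getEj() { return Ej; }
    public void setEj(double Ej) { this.Ej = Ej; }
}

public class SmoHelpers {

    // Clip aj into [L, H]; argument order matches clipAlpha(aj, H, L) in innerL
    static double clipAlpha(double aj, double H, double L) {
        if (aj > H) {
            aj = H;
        }
        if (L > aj) {
            aj = L;
        }
        return aj;
    }

    public static void main(String[] args) {
        SelectRs sr = new SelectRs();
        sr.setJ(3);
        sr.setEj(-0.25);
        System.out.println(sr.getJ() + " " + sr.getEj());
        System.out.println(clipAlpha(1.5, 1.0, 0.0) + " " + clipAlpha(-0.2, 1.0, 0.0));
    }
}
```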

Function that optimizes a pair of alphas:

// Inner loop: given the first index i, pick a second index j and jointly
// optimize alpha_i and alpha_j. Returns 1 if the pair changed, otherwise 0.
private static int innerL(int i) {
		
		DenseMatrix64F alphas = os.getAlphas();
		DenseMatrix64F labels = os.getLabelMat();
		DenseMatrix64F K = os.getK();
		
		double Ei = calcEk(i);
		double yi = labels.get(i, 0);
		
		// only optimize alpha_i if it violates the KKT conditions beyond the tolerance
	    if (((yi * Ei < -os.getTol()) && (alphas.get(i, 0) < os.getC())) 
	    	   || ((yi * Ei > os.getTol()) && (alphas.get(i, 0) > 0))){
	    	SelectRs sr = selectJ(i,Ei);
	    	int j = sr.getJ();
	    	double Ej = sr.getEj();
	    	double yj = labels.get(j, 0);
	    	
	        double alphaIold = alphas.get(i, 0);
	        double alphaJold = alphas.get(j, 0);
	        
	        // L and H keep alpha_j in [0, C] while preserving sum(alpha * y) = 0
	        double L,H;
	        if (yi != yj) {
	            L = Math.max(0, alphaJold - alphaIold);
	            H = Math.min(os.getC(), os.getC() + alphaJold - alphaIold);
	        }else{
	            L = Math.max(0, alphaJold + alphaIold - os.getC());
	            H = Math.min(os.getC(), alphaJold + alphaIold);
	        }
	        if(L == H) {
	        	System.out.println("L==H");
	        	return 0;
	        }
	        
	        // eta = 2*K(i,j) - K(i,i) - K(j,j) must be negative for this update
	        double eta = 2.0 * K.get(i, j) - K.get(i, i) - K.get(j, j);
	        if(eta >= 0) {
	        	System.out.println("eta>=0");
	            return 0;
	        }
	        
	        // update alpha_j, clip it to [L, H], and refresh its cached error
	        alphas.set(j, 0, alphaJold - yj * (Ei - Ej) / eta);
	        alphas.set(j, 0, clipAlpha(alphas.get(j, 0), H, L));
	        os.geteCache().set(j, 0, 1);
	        os.geteCache().set(j, 1, calcEk(j));
	        if (Math.abs(alphas.get(j, 0) - alphaJold) < 0.00001) {
	        	System.out.println("j not moving enough");
	            return 0;
	        }
	        
	        // move alpha_i by the same amount in the opposite direction, then refresh E_i
	        alphas.set(i, 0, alphaIold + yj * yi * (alphaJold - alphas.get(j, 0)));
	        os.geteCache().set(i, 0, 1);
	        os.geteCache().set(i, 1, calcEk(i));
	        
	        // recompute the threshold b: b1 uses K(i,i), b2 uses K(j,j)
	        double dAi = alphas.get(i, 0) - alphaIold;
	        double dAj = alphas.get(j, 0) - alphaJold;
	        double b1 = os.getB() - Ei - yi * dAi * K.get(i, i) - yj * dAj * K.get(i, j);
	        double b2 = os.getB() - Ej - yi * dAi * K.get(i, j) - yj * dAj * K.get(j, j);
	        if ((0 < alphas.get(i, 0)) && (os.getC() > alphas.get(i, 0)))
	            os.setB(b1);
	        else if ((0 < alphas.get(j, 0)) && (os.getC() > alphas.get(j, 0)))
	        	os.setB(b2);
	        else
	        	os.setB((b1 + b2) / 2.0);
	        return 1;
	    }
	    else{
	        return 0;
	    }
		
	}

Then the main SMO function:

public static void smoP(DenseMatrix64F dataMatIn,double[] classLabels,double  C, double toler, int maxIter, String [] kTup) {
		
		os = new OptStruct(dataMatIn,classLabels, C, toler, kTup);
	    int iter = 0;
	    boolean entireSet = true;
	    int alphaPairsChanged = 0;
	    while( (iter < maxIter) && ((alphaPairsChanged > 0) || (entireSet))){
	        alphaPairsChanged = 0;
	        if(entireSet) {
	        	
	        	for(int i=0;i<os.getM();i++) {
	                alphaPairsChanged += innerL(i);
	                System.out.println("fullSet, iter: "+iter+" i:"+i+", pairs changed "+alphaPairsChanged);
	        	}
	            iter += 1;
	        }else{
	        	Integer[] nonBoundIs = between0C(os.getAlphas(),0,0,C);
	        	for(int i=0;i<nonBoundIs.length;i++) {
	        		 alphaPairsChanged += innerL(nonBoundIs[i]);
	        		 System.out.println("non-bound, iter: "+iter+" i:"+nonBoundIs[i]+", pairs changed "+alphaPairsChanged);
	        	}
	            iter += 1;
	        }
	        if(entireSet)
	            entireSet = false;
	        else if (alphaPairsChanged == 0)
	            entireSet = true;
	        System.out.println("iteration number: "+iter);
	    }   	
	}
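smoP alternates between full passes and passes over the non-bound alphas (0 < α < C), and selectJ only considers indices whose error-cache flag is set. The helpers nonzero and between0C that perform this selection are not shown in the article. The sketch below is a hypothetical plain-array stand-in (the real ones take DenseMatrix64F) that only illustrates the intended semantics, with the same argument order as the calls above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class IndexHelpers {

    // Row indices whose entry in column col is non-zero
    // (used with the eCache "valid" flag in column 0)
    static Integer[] nonzero(double[][] mat, int col) {
        List<Integer> idx = new ArrayList<>();
        for (int r = 0; r < mat.length; r++) {
            if (mat[r][col] != 0) {
                idx.add(r);
            }
        }
        return idx.toArray(new Integer[0]);
    }

    // Row indices with low < mat[r][col] < high; with low = 0 and high = C
    // these are the non-bound alphas that smoP revisits
    static Integer[] between0C(double[][] mat, int col, double low, double high) {
        List<Integer> idx = new ArrayList<>();
        for (int r = 0; r < mat.length; r++) {
            double v = mat[r][col];
            if (v > low && v < high) {
                idx.add(r);
            }
        }
        return idx.toArray(new Integer[0]);
    }

    public static void main(String[] args) {
        double[][] alphas = {{0.0}, {0.3}, {0.6}, {0.1}};
        double[][] eCache = {{1, -0.5}, {0, 0}, {1, 0.2}};
        System.out.println(Arrays.toString(between0C(alphas, 0, 0, 0.6)));
        System.out.println(Arrays.toString(nonzero(eCache, 0)));
    }
}
```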

Method that computes the weight vector w = Σ α_i·y_i·x_i (this is only meaningful for the linear kernel):

public static double[] calcWs(DenseMatrix64F alphas,DenseMatrix64F dataArr,double[] classLabels) {
		
		double [] w = new double[dataArr.numCols];
		
		for(int i=0;i<dataArr.numRows;i++) {
			
			double tmp = alphas.get(i, 0)*classLabels[i];
			
			for(int j=0;j<dataArr.numCols;j++) {
				
				w[j] += tmp*dataArr.get(i, j);
			}
		}
		
        return w;
	}
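calcWs relies on the identity w = Σ α_i·y_i·x_i, which holds for the linear kernel: with it, the kernel-form decision value Σ α_i·y_i·⟨x_i, x⟩ + b equals w·x + b. A small plain-array check (hypothetical names, made-up alphas and data) shows that the two forms agree:

```java
public class WeightVectorDemo {

    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) {
            s += a[k] * b[k];
        }
        return s;
    }

    // w = sum_i alpha_i * y_i * x_i, the same accumulation calcWs performs
    static double[] weights(double[] alphas, double[][] X, double[] y) {
        double[] w = new double[X[0].length];
        for (int i = 0; i < X.length; i++) {
            for (int k = 0; k < w.length; k++) {
                w[k] += alphas[i] * y[i] * X[i][k];
            }
        }
        return w;
    }

    // Kernel-form decision value with the linear kernel: sum_i alpha_i y_i (x_i . x) + b
    static double dualDecision(double[] alphas, double[][] X, double[] y, double b, double[] x) {
        double f = b;
        for (int i = 0; i < X.length; i++) {
            f += alphas[i] * y[i] * dot(X[i], x);
        }
        return f;
    }

    public static void main(String[] args) {
        double[][] X = {{1, 2}, {3, 0}};
        double[] y = {1, -1};
        double[] alphas = {0.5, 0.25};
        double[] x = {2, 1};
        double[] w = weights(alphas, X, y);
        System.out.println(dot(w, x) + 1.0);                    // primal form, b = 1
        System.out.println(dualDecision(alphas, X, y, 1.0, x)); // kernel form, b = 1
    }
}
```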

Finally, the classifier:

	public static double classify(double [] intX,double[]  ws,double b) {
		
		double prob = 0;
		
		for(int i=0;i<intX.length;i++) {
			prob += intX[i]*ws[i];
		}
		
		if((prob+b) > 0)
			return 1.0;
		else
			return -1.0;	
		
	}
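classify uses the weight vector w, so it only applies to a model trained with the linear kernel ("lin"). For the "rbf" option that kernelTrans supports, prediction has to stay in kernel form, f(x) = Σ α_i·y_i·k(x_i, x) + b, summed over the support vectors. A minimal sketch under that assumption follows; the method names and the toy model (one support vector per class, b = 0, σ = 1) are hypothetical.

```java
public class KernelClassifyDemo {

    // RBF kernel exp(-||a - b||^2 / sigma^2), the same form kernelTrans uses
    static double rbf(double[] a, double[] b, double sigma) {
        double sq = 0;
        for (int k = 0; k < a.length; k++) {
            double d = a[k] - b[k];
            sq += d * d;
        }
        return Math.exp(sq / (-1 * sigma * sigma));
    }

    // sign(sum_i alpha_i y_i k(x_i, x) + b); only support vectors (alpha_i > 0) contribute
    static double classify(double[] x, double[][] X, double[] y, double[] alphas,
                           double b, double sigma) {
        double f = b;
        for (int i = 0; i < X.length; i++) {
            if (alphas[i] > 0) {
                f += alphas[i] * y[i] * rbf(X[i], x, sigma);
            }
        }
        return f > 0 ? 1.0 : -1.0;
    }

    public static void main(String[] args) {
        double[][] X = {{1, 0}, {-1, 0}};
        double[] y = {1, -1};
        double[] alphas = {1, 1};
        System.out.println(classify(new double[] {2, 0}, X, y, alphas, 0, 1.0));
        System.out.println(classify(new double[] {-2, 0}, X, y, alphas, 0, 1.0));
    }
}
```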

Now let's test it:

		List<String> list = new ArrayList<String>();
        // read the tab-separated test data (x1, x2, label per line);
        // try-with-resources closes the reader even if reading fails
        try (BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch06\\testSet.txt"))) {
            String s;
            while((s = br.readLine())!=null){
            	list.add(s);
            }
        }catch(Exception e){
            e.printStackTrace();
        }
        
        DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(),2);
        double[] classLabels = new double[list.size()];
        
        for(int i=0;i<list.size();i++) {
        	
        	String[] items = list.get(i).split("\t");
        	dataMatIn.set(i, 0, Double.parseDouble(items[0]));
        	dataMatIn.set(i,1, Double.parseDouble(items[1]));
        	classLabels[i] = Double.parseDouble(items[2]);
        }
        
        smoP(dataMatIn,classLabels,0.6 ,0.001, 40, new String [] {"lin", "0"});
        
        System.out.println(os.getB());

        double[]  ws = calcWs(os.getAlphas(),dataMatIn,classLabels);
        
        System.out.println(ws[0]);
        System.out.println(ws[1]);
        
        for(int i=0;i<dataMatIn.numRows;i++) {
        	
        	double [] intX = new double[dataMatIn.numCols];
        	
        	for(int j=0;j<dataMatIn.numCols;j++)
        		intX[j] = dataMatIn.get(i, j);
        	
        	System.out.println(list.get(i)+"    predicted class: "+classify(intX,ws,os.getB()));
        }

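Running the program prints b, the two components of w, and then each training line followed by its predicted class.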

As the output shows, the predicted classes agree exactly with the original labels of the training data. Done.
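Instead of comparing the printed lines by eye, the agreement can be summarized as an error rate: the fraction of samples whose predicted label differs from the true one. A minimal hypothetical helper (the arrays in main are made up) that could be fed the classify results from the test loop:

```java
public class TrainingErrorDemo {

    // Fraction of samples whose predicted label differs from the true label
    static double errorRate(double[] predicted, double[] labels) {
        int errors = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predicted[i] != labels[i]) {
                errors++;
            }
        }
        return (double) errors / labels.length;
    }

    public static void main(String[] args) {
        double[] predicted = {1, -1, 1, 1};
        double[] labels = {1, -1, -1, 1};
        System.out.println(errorRate(predicted, labels));   // one mismatch out of four
    }
}
```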

路边草随风
