自己实现的SVM源码(Java,基于SMO算法)

首先是Data类(读取训练数据):

import java.awt.print.Printable;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

/**
 * Loads labelled training points for the SVM from a whitespace-separated text file.
 *
 * <p>Each non-blank line holds the feature values followed by an integer class label,
 * for example {@code 3.54\t1.97\t-1}.
 */
public class Data {
	/** File read by {@link #getTrainData()}. */
	public static final String DEFAULT_PATH = "G://download//testSet.txt";

	/**
	 * Reads the training set from {@link #DEFAULT_PATH}.
	 *
	 * @return map from feature vector to label, as described in {@link #getTrainData(String)}
	 */
	public Map<List<Double>, Integer> getTrainData() {
		return getTrainData(DEFAULT_PATH);
	}

	/**
	 * Reads a training set from {@code path}.
	 *
	 * <p>Each line is trimmed and split on runs of whitespace. Every column except the last
	 * is parsed as a {@code double} feature; the last column is parsed as an {@code int}
	 * label. Blank lines are skipped. Entries keep file order, so anything iterating the
	 * result (such as SVM training) sees the points in a reproducible order. Because points
	 * are map keys, duplicate feature vectors collapse into a single entry that carries the
	 * label of the last occurrence.
	 *
	 * @param path file to read
	 * @return map from feature vector (an {@link ArrayList}) to label; empty if the file
	 *     cannot be opened
	 * @throws NumberFormatException if a non-blank line contains a non-numeric token
	 */
	public Map<List<Double>, Integer> getTrainData(String path) {
		Map<List<Double>, Integer> data = new LinkedHashMap<>();
		// try-with-resources releases the file handle even if parsing throws.
		try (Scanner in = new Scanner(new File(path))) {
			while (in.hasNextLine()) {
				String line = in.nextLine().trim();
				if (line.isEmpty()) {
					continue; // e.g. a trailing newline at the end of the file
				}
				String[] cols = line.split("\\s+");
				List<Double> point = new ArrayList<>(cols.length - 1);
				for (int i = 0; i < cols.length - 1; i++) {
					point.add(Double.parseDouble(cols[i]));
				}
				data.put(point, Integer.parseInt(cols[cols.length - 1]));
			}
		} catch (FileNotFoundException e) {
			// Best-effort contract kept from the original: report and return what we have.
			System.err.println("Cannot open training data file " + path + ": " + e.getMessage());
		}
		return data;
	}

	public static void main(String[] args) {
		Data data = new Data();
		data.getTrainData();
	}
}

  SVM类:

import java.awt.print.Printable;
import java.io.FileNotFoundException;
import java.io.ObjectInputStream.GetField;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Map.Entry;

/**
 * Soft-margin support vector machine trained with a simplified SMO (sequential minimal
 * optimisation) solver.
 *
 * <p>The decision function is {@code g(x) = sum_i alpha_i * y_i * K(x, x_i) + b}, where
 * {@code K} is the kernel returned by {@link #kernel(List, List)} (currently the Gaussian
 * kernel). Training data is loaded in the constructor via {@code new Data().getTrainData()}.
 */
public class SVM {
	private List<ArrayList<Double>> trainData;
	private List<Integer> labelTrainData;
	/** Width parameter of the Gaussian kernel. */
	private double sigma;
	/** Soft-margin penalty; every alpha is kept in [0, C]. */
	private double C;
	private List<Double> alpha;
	private double b;
	/** Cached errors E_i = g(x_i) - y_i, refreshed by {@link #updateE()}. */
	private List<Double> E;
	private int N;
	private int dim;
	/** Tolerance used when checking the KKT conditions. */
	private double tol;
	/** eta = K11 + K22 - 2*K12 for the pair last chosen by selectAlpha2; consumed by update. */
	private double eta;
	/** An alpha within eps of 0 or C is treated as sitting on that bound. */
	private double eps;
	/** Minimum alpha change for an update to count as progress. */
	private double eps2;

	/**
	 * Checks whether training sample {@code id} satisfies the KKT conditions within {@code tol}.
	 *
	 * @param id index of the training sample
	 * @return {@code false} if the sample violates its KKT condition
	 */
	public boolean satisfyKkt(int id) {
		double ypgx = this.labelTrainData.get(id) * getGx(this.trainData.get(id)); // y*g(x)
		double a = this.alpha.get(id);
		if (Math.abs(a) <= this.eps) {
			// alpha == 0 requires y*g(x) >= 1
			if (ypgx - 1 < -this.tol) {
				return false;
			}
		} else if (Math.abs(a - this.C) <= this.eps) {
			// alpha == C requires y*g(x) <= 1
			if (ypgx - 1 > this.tol) {
				return false;
			}
		} else {
			// 0 < alpha < C requires y*g(x) == 1
			if (Math.abs(ypgx - 1) > this.tol) {
				return false;
			}
		}
		return true;
	}

	/** Recomputes the cached error E_i = g(x_i) - y_i for every training sample. */
	public void updateE() {
		for (int i = 0; i < this.N; i++) {
			this.E.set(i, getGx(this.trainData.get(i)) - this.labelTrainData.get(i));
		}
	}

	/**
	 * Linear kernel: the dot product of {@code X} and {@code Y}.
	 *
	 * <p>Iterates over {@code Y.size()} components; {@code X} must be at least that long.
	 */
	public double kernelLinear(List<Double> X, List<Double> Y) {
		int len = Y.size();
		double s = 0;
		for (int i = 0; i < len; i++) {
			s += X.get(i) * Y.get(i);
		}
		return s;
	}

	/**
	 * Gaussian kernel {@code exp(-||X - Y||^2 / (2 * sigma^2))}.
	 *
	 * <p>Iterates over {@code Y.size()} components; {@code X} must be at least that long.
	 */
	public double kernelRBF(List<Double> X, List<Double> Y) {
		int len = Y.size();
		double s = 0;
		for (int i = 0; i < len; i++) {
			s += (X.get(i) - Y.get(i)) * (X.get(i) - Y.get(i));
		}
		s = Math.exp(-s / (2 * Math.pow(this.sigma, 2)));
		return s;
	}

	/**
	 * The kernel used everywhere in training and prediction. Returning
	 * {@link #kernelLinear(List, List)} here turns this into a linear SVM.
	 */
	private double kernel(List<Double> X, List<Double> Y) {
		return kernelRBF(X, Y);
	}

	/**
	 * Evaluates the decision function {@code g(X) = sum_i alpha_i * y_i * K(X, x_i) + b}.
	 *
	 * @param X point to evaluate
	 * @return the raw (unthresholded) decision value
	 */
	public double getGx(List<Double> X) {
		double s = 0;
		for (int i = 0; i < this.N; i++) {
			s += this.alpha.get(i) * this.labelTrainData.get(i) * kernel(X, this.trainData.get(i));
		}
		return s + this.b;
	}

	/**
	 * Jointly optimises the pair (alpha[x1], alpha[x2]) and updates b and the error cache.
	 *
	 * <p>Requires {@link #eta} to have been set for this pair, as {@link #selectAlpha2(int)}
	 * does.
	 *
	 * @param x1 index of the first alpha
	 * @param x2 index of the second alpha
	 * @return 1 if the alphas changed, 0 if no useful step was possible
	 */
	public int update(int x1, int x2) {
		if (!(this.eta > this.eps)) {
			return 0; // the analytic step divides by eta and needs eta > 0
		}
		List<Double> p1 = this.trainData.get(x1);
		List<Double> p2 = this.trainData.get(x2);
		int y1 = this.labelTrainData.get(x1); // unboxed: labels are compared by value below
		int y2 = this.labelTrainData.get(x2);
		double oldAlpha1 = this.alpha.get(x1);
		double oldAlpha2 = this.alpha.get(x2);
		double e1 = this.E.get(x1);
		double e2 = this.E.get(x2);

		// Feasible segment [low, high] for alpha2: both alphas stay in [0, C] while
		// y1*alpha1 + y2*alpha2 is held constant.
		double low;
		double high;
		if (y1 == y2) {
			low = Math.max(0, oldAlpha1 + oldAlpha2 - this.C);
			high = Math.min(this.C, oldAlpha2 + oldAlpha1);
		} else {
			low = Math.max(0, oldAlpha2 - oldAlpha1);
			high = Math.min(this.C, oldAlpha2 - oldAlpha1 + this.C);
		}
		if (high - low <= this.eps) {
			return 0; // degenerate segment: alpha2 cannot move
		}

		double newAlpha2 = oldAlpha2 + y2 * (e1 - e2) / this.eta;
		if (newAlpha2 > high) {
			newAlpha2 = high;
		} else if (newAlpha2 < low) {
			newAlpha2 = low;
		}
		double newAlpha1 = oldAlpha1 + y1 * y2 * (oldAlpha2 - newAlpha2);

		newAlpha1 = snapToBounds(newAlpha1);
		newAlpha2 = snapToBounds(newAlpha2);
		if (Math.abs(newAlpha1 - oldAlpha1) <= this.eps2
				|| Math.abs(newAlpha2 - oldAlpha2) <= this.eps2) {
			return 0; // negligible progress
		}

		double k11 = kernel(p1, p1);
		double k12 = kernel(p1, p2);
		double k22 = kernel(p2, p2);
		double delta1 = newAlpha1 - oldAlpha1;
		double delta2 = newAlpha2 - oldAlpha2;
		double b1 = -e1 - y1 * k11 * delta1 - y2 * k12 * delta2 + this.b;
		double b2 = -e2 - y1 * k12 * delta1 - y2 * k22 * delta2 + this.b;

		// A non-bound alpha pins b exactly; otherwise take the midpoint.
		if (newAlpha1 > 0 && newAlpha1 < this.C) {
			this.b = b1;
		} else if (newAlpha2 > 0 && newAlpha2 < this.C) {
			this.b = b2;
		} else {
			this.b = (b1 + b2) / 2;
		}

		this.alpha.set(x1, newAlpha1);
		this.alpha.set(x2, newAlpha2);
		updateE();
		return 1;
	}

	/** Snaps {@code a} to 0 or C when it lies within eps of that bound. */
	private double snapToBounds(double a) {
		if (Math.abs(a) <= this.eps) {
			a = 0;
		}
		if (Math.abs(a - this.C) <= this.eps) {
			a = this.C;
		}
		return a;
	}

	/** Returns true when alpha[i] is (within eps) at 0 or at C. */
	private boolean isAtBound(int i) {
		double a = this.alpha.get(i);
		return Math.abs(a) <= this.eps || Math.abs(a - this.C) <= this.eps;
	}

	/** Returns eta = K(i,i) + K(j,j) - 2*K(i,j), the curvature along the pair's constraint line. */
	private double computeEta(int i, int j) {
		List<Double> pi = this.trainData.get(i);
		List<Double> pj = this.trainData.get(j);
		return kernel(pi, pi) + kernel(pj, pj) - 2 * kernel(pi, pj);
	}

	/**
	 * Chooses the second alpha to optimise together with {@code x1}, and stores the pair's eta.
	 *
	 * <p>It first tries the non-bound sample (other than x1) that maximises |E(x1) - E(i)|.
	 * If that fails, it falls back to the first other sample with eta greater than eps.
	 *
	 * @param x1 index of the first alpha
	 * @return index of the second alpha, or -1 if no sample gives a usable eta
	 */
	public int selectAlpha2(int x1) {
		int x2 = -1;
		double maxDiff = -1;
		for (int i = 0; i < this.N; ++i) {
			if (i == x1 || isAtBound(i)) {
				continue;
			}
			double diff = Math.abs(this.E.get(x1) - this.E.get(i));
			if (diff > maxDiff) {
				maxDiff = diff;
				x2 = i;
			}
		}

		// eta must be strictly positive for the analytic step in update().
		if (x2 != -1) {
			this.eta = computeEta(x1, x2);
			if (this.eta > this.eps) {
				return x2;
			}
		}

		for (int i = 0; i < this.N; i++) {
			if (i == x1) {
				continue;
			}
			this.eta = computeEta(x1, i);
			if (this.eta > this.eps) {
				return i;
			}
		}
		return -1;
	}

	/** Runs SMO until no pass changes any alpha. */
	public void SMO() {
		SMO(Integer.MAX_VALUE);
	}

	/**
	 * Runs SMO until no pass changes any alpha, or {@code maxPasses} passes have run.
	 *
	 * <p>Each pass first sweeps the non-bound alphas. Only if none of them changed does it
	 * sweep every sample.
	 *
	 * @param maxPasses upper bound on the number of outer passes
	 */
	public void SMO(int maxPasses) {
		int cnt = 0;
		while (cnt < maxPasses) {
			cnt++;
			System.out.println(cnt);

			int numChanged = 0;
			for (int x1 = 0; x1 < this.N; ++x1) {
				if (isAtBound(x1)) {
					continue;
				}
				numChanged += tryOptimize(x1);
			}
			if (numChanged == 0) {
				for (int x1 = 0; x1 < this.N; ++x1) {
					// Count only real changes; counting attempts could loop forever.
					numChanged += tryOptimize(x1);
				}
			}
			if (numChanged == 0) {
				break;
			}
		}
	}

	/** Optimises x1 with a chosen partner if x1 violates KKT; returns 1 if alphas changed. */
	private int tryOptimize(int x1) {
		if (satisfyKkt(x1)) {
			return 0;
		}
		int x2 = selectAlpha2(x1);
		if (x2 == -1) {
			return 0;
		}
		return update(x1, x2);
	}

	/** Loads the training data via {@code new Data().getTrainData()} and sets the hyperparameters. */
	public SVM() {
		Data data = new Data();
		Map<List<Double>, Integer> points = data.getTrainData();
		this.trainData = new ArrayList<ArrayList<Double>>();
		this.labelTrainData = new ArrayList<Integer>();
		this.alpha = new ArrayList<Double>();
		this.E = new ArrayList<Double>();

		for (Map.Entry<List<Double>, Integer> entry : points.entrySet()) {
			int label = entry.getValue();
			this.trainData.add(new ArrayList<Double>(entry.getKey())); // copy, no unchecked cast
			this.labelTrainData.add(label);
			this.alpha.add(0.0);
			// With every alpha and b at 0, g(x) = 0, so E_i = -y_i.
			this.E.add(0.0 - label);
		}
		this.N = this.labelTrainData.size();
		this.dim = this.N == 0 ? 0 : this.trainData.get(0).size();

		this.sigma = 12;
		this.C = 0.5;
		this.b = 0.0;
		this.tol = 0.001;
		this.eta = 0;
		this.eps = 0.0000001;
		this.eps2 = 0.00001;
	}

	/** Returns the current bias term b. */
	public double getB() {
		return this.b;
	}

	/**
	 * Computes {@code w = sum_i alpha_i * y_i * x_i} in input space.
	 *
	 * <p>This only describes the decision boundary {@code w.x + b = 0} when the linear
	 * kernel is in use.
	 *
	 * @return weight vector of length {@code dim}
	 */
	public double[] getLinearW() {
		double[] w = new double[this.dim];
		for (int i = 0; i < this.N; i++) {
			double coef = this.alpha.get(i) * this.labelTrainData.get(i);
			List<Double> xi = this.trainData.get(i);
			for (int j = 0; j < this.dim; j++) {
				w[j] += coef * xi.get(j);
			}
		}
		return w;
	}

	/**
	 * Classifies {@code x}.
	 *
	 * @param x point to classify
	 * @return 1 if g(x) > 0, otherwise -1
	 */
	public int predict(List<Double> x) {
		return getGx(x) > 0 ? 1 : -1;
	}

	public static void main(String[] args) throws FileNotFoundException {
		SVM s = new SVM();
		s.SMO();
		// Writes "x1<TAB>x2<TAB>predictedLabel" lines; uses only the first two features.
		try (PrintWriter out = new PrintWriter("g://download//resultpoints.txt")) {
			for (int i = 0; i < s.N; i++) {
				List<Double> p = s.trainData.get(i);
				out.write(p.get(0).toString());
				out.write("\t");
				out.write(p.get(1).toString());
				out.write("\t");
				out.write(Integer.toString(s.predict(p)));
				out.write("\n");
			}
		}
		// Only meaningful with the linear kernel: the boundary is then w.x + b = 0.
		double[] w = s.getLinearW();
		System.out.println(w[0] + " " + w[1] + " " + s.b + "======");
	}

}

  

用线性核函数实现的SVM得到的分类结果:

画图用的是Python代码:

from numpy import *  
import matplotlib  
import matplotlib.pyplot as plt  
import numpy as np

# File written by SVM.main on the Java side: "x1<TAB>x2<TAB>predictedLabel" per line.
RESULT_PATH = "g://download//resultpoints.txt"


def load_points(path):
    """Read the classified points and split them by predicted label.

    Each non-blank line must be ``x1<TAB>x2<TAB>label``; blank lines are skipped.

    Returns ``(pos_x, pos_y, neg_x, neg_y)``: coordinates of the points labelled 1, and
    of all other points.
    """
    pos_x, pos_y, neg_x, neg_y = [], [], [], []
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            x1, x2, label = float(fields[0]), float(fields[1]), int(fields[2])
            if label == 1:
                pos_x.append(x1)
                pos_y.append(x2)
            else:
                neg_x.append(x1)
                neg_y.append(x2)
    return pos_x, pos_y, neg_x, neg_y


def plot_result(path=RESULT_PATH,
                w1=0.8148005405344305,
                w2=-0.27263471796762484,
                b=-3.8392586254518437):
    """Scatter the classified points and draw the linear boundary w1*x1 + w2*x2 + b = 0.

    The default w1, w2 and b come from a linear-kernel run (printed by SVM.main).
    """
    pos_x, pos_y, neg_x, neg_y = load_points(path)

    plt.figure(figsize=(8, 5), dpi=80)
    axes = plt.subplot(111)
    type1 = axes.scatter(pos_x, pos_y, s=40, c='red')
    type2 = axes.scatter(neg_x, neg_y, s=40, c='green')

    # Boundary solved for x2; requires w2 != 0.
    x = np.linspace(-4, 10, 200)
    y = (-w1 / w2) * x + (-b / w2)
    axes.plot(x, y, 'b', lw=3)

    plt.xlabel('x1')
    plt.ylabel('x2')
    # Legend text matches the labels actually drawn (SVM.predict returns 1 or -1).
    axes.legend((type1, type2), ('1', '-1'), loc=1)
    plt.show()


if __name__ == "__main__":
    plot_result()


#0.8148005405344305 -0.27263471796762484 -3.8392586254518437

用高斯核,当C=6、sigma=1时:

用高斯核,当C=0.5、sigma=1时:

 

用高斯核,当C=0.5、sigma=12时:

 

 

这说明C和sigma的取值对高斯核SVM的分类结果影响很大。

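其中sigma是高斯核函数 $K(x,y)=\exp\left(-\frac{\lVert x-y\rVert^2}{2\sigma^2}\right)$ 的宽度参数;C是软间隔惩罚参数,所有 $\alpha_i$ 都被限制在 $[0, C]$ 内。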

转载于:https://www.cnblogs.com/wuxiangli/p/6275112.html

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值