梯度上升法求解Logistic回归

对率函数 hθ = 1/(1+e^(-z))

z=β^T(x;1)

P{yi|xi;θ} = (P{yi=1|xi}^yi)*(P{yi=0|xi}^(1-yi))

极大似然函数为 ΠP{yi|xi;θ} = Π(P{yi=1|xi}^yi)*(P{yi=0|xi}^(1-yi))

假定P{yi=1|xi}=hθ(xi),则p{yi=0|xi}=1-hθ(xi)

则 ΠP{yi|xi;θ} = Π(P{yi=1|xi}^yi)*(P{yi=0|xi}^(1-yi)) = Πhθ(xi)^yi * (1-hθ(xi))^(1-yi)

两边同时取对数,则

ln ΠP{yi|xi;θ} = ln[Πhθ(xi)^yi * (1-hθ(xi))^(1-yi)]

                =∑[yi*lnhθ(xi) + (1-yi)ln(1-hθ(xi))]

即L(θ) = ∑lnP{yi|xi;θ}=∑[yi*lnhθ(xi) + (1-yi)ln(1-hθ(xi))]

∂L/∂θj = ∑ x_ij*(yi-hθ(xi))   （对样本 i 求和，x_ij 为第 i 个样本的第 j 个分量）

θj = θj + α∑ x_ij*(yi-hθ(xi))

θj = θj + α∑ x_ij*(yi-hθ(xi)) 为批处理梯度上升法,即θ的每一次更新均需要m个样本的参与

实现如下:

public class Logistic {
	float[] Tag;            // 0/1 class label per sample, parsed from the last column
	float[][] Var;          // design matrix; Var[i][0] is the constant 1 (intercept term)
	float[] beta;           // learned parameter vector, set by Logistic_main
	static ChartPanel frame1;
	static XYSeries series1 = new XYSeries("positive");
	static XYSeries series2 = new XYSeries("negative");
	static XYSeries series3 = new XYSeries("result");

	/**
	 * Loads the space-separated data set: each line holds the feature values
	 * followed by a 0/1 label. Fills Var (prepending a 1 for the intercept)
	 * and Tag, and adds each point to the positive/negative plot series.
	 * Handles any number of feature columns (the original hard-coded two);
	 * plotting still uses columns 1 and 2.
	 *
	 * @throws IOException if the data file cannot be read
	 */
	public Logistic() throws IOException {
		List<String> rows = new ArrayList<String>();
		List<String> labels = new ArrayList<String>();
		// try-with-resources: the original leaked the reader on every path
		try (BufferedReader br = new BufferedReader(
				new FileReader("C:\\E\\machinelearningdataset\\dataset3.txt"))) {
			String line;
			while ((line = br.readLine()) != null) {
				String[] content = line.split(" ");
				StringBuilder features = new StringBuilder();
				for (int i = 0; i < content.length - 1; i++) {
					features.append(' ').append(content[i]);
				}
				rows.add(features.toString().trim());
				labels.add(content[content.length - 1]);
			}
		}
		int featureCount = rows.get(0).split(" ").length;
		this.Tag = new float[labels.size()];
		this.Var = new float[rows.size()][featureCount + 1];
		for (int i = 0; i < labels.size(); i++) {
			this.Var[i][0] = 1.0f;                       // intercept column
			String[] cols = rows.get(i).split(" ");
			for (int j = 0; j < featureCount; j++) {     // generalized: any feature count
				this.Var[i][j + 1] = Float.parseFloat(cols[j]);
			}
			this.Tag[i] = Float.parseFloat(labels.get(i));
			if (this.Tag[i] == 1)
				series1.add(this.Var[i][1], this.Var[i][2]);
			else
				series2.add(this.Var[i][1], this.Var[i][2]);
		}
	}

	/**
	 * Returns exp(β·x) where x[0] == 1.
	 * NOTE: overflows to Infinity for large β·x, which made the original
	 * e/(1+e) expressions produce NaN; hθ(x) is now computed via the
	 * stable sigmoid() helper instead. Kept unchanged for compatibility.
	 */
	public float sum_exp(float x_i[], float beta[]) {
		float z = 0;
		for (int j = 0; j < x_i.length; j++) {
			z = z + x_i[j] * beta[j];
		}
		return (float) Math.exp(z);
	}

	/** Numerically stable logistic function hθ(x) = 1/(1+e^(-β·x)). */
	private float sigmoid(float[] x_i, float[] beta) {
		float z = 0;
		for (int j = 0; j < x_i.length; j++) {
			z = z + x_i[j] * beta[j];
		}
		if (z >= 0) {
			return (float) (1.0 / (1.0 + Math.exp(-z)));  // e^(-z) <= 1: no overflow
		}
		double e = Math.exp(z);                           // z < 0 so e <= 1: no overflow
		return (float) (e / (1.0 + e));
	}

	/**
	 * Partial derivative of the log-likelihood L(β) with respect to β_j:
	 * sum over samples i of x_ij * (y_i - hθ(x_i)).
	 */
	public float Logistic_D(float x[][], float y[], float beta[], int j) {
		float grad = 0;
		for (int i = 0; i < x.length; i++) {
			grad = grad + (y[i] - sigmoid(x[i], beta)) * x[i][j];
		}
		return grad;
	}

	/**
	 * Batch gradient ascent: 5000 sweeps over the full gradient.
	 * Fix: the original aliased beta_tmp to beta, so later components of β
	 * were updated from partially-updated earlier components instead of the
	 * old β. All components are now updated simultaneously from the same β.
	 * The result is copied back into the caller's array (main plots from it)
	 * and stored in this.beta.
	 *
	 * @param x    design matrix (x[i][0] == 1)
	 * @param y    0/1 labels
	 * @param beta initial parameters; overwritten with the trained values
	 * @param a    learning rate α
	 */
	public void Logistic_main(float x[][], float y[], float beta[], float a) {
		float[] next = new float[beta.length];
		for (int iter = 0; iter < 5000; iter++) {
			for (int j = 0; j < beta.length; j++) {
				next[j] = beta[j] + a * Logistic_D(x, y, beta, j);  // simultaneous update
			}
			System.arraycopy(next, 0, beta, 0, beta.length);
		}
		this.beta = beta;
	}

	/**
	 * Prints actual label vs. predicted probability P{y=1|x} per sample.
	 * NOTE: like the original, this predicts with this.beta and ignores the
	 * beta parameter — kept for interface compatibility.
	 */
	public void Logistic_predict(float x[][], float y[], float beta[]) {
		for (int i = 0; i < y.length; i++) {
			float p = sigmoid(x[i], this.beta);
			System.out.println("Actual:" + Tag[i] + "    Predict:" + p);
		}
	}

	public static void main(String[] args) throws IOException {
		Logistic a = new Logistic();
		float[] beta = new float[a.Var[0].length];  // Java zero-initializes arrays
		a.Logistic_main(a.Var, a.Tag, beta, 0.001f);
		a.Logistic_predict(a.Var, a.Tag, a.beta);
		// Decision boundary: β0 + β1*x + β2*y = 0  =>  y = (-β0 - β1*x)/β2
		// NOTE(review): x[0] == x[1] == -3 in the original; likely meant {-3, 0, 3}.
		double[] x = new double[] { -3, -3, 3 };
		double[] y = new double[3];
		XYSeriesCollection dataset = new XYSeriesCollection();
		dataset.addSeries(series1);
		dataset.addSeries(series2);
		for (int i = 0; i < x.length; i++) {
			y[i] = (-beta[0] - beta[1] * x[i]) / beta[2];
			series3.add(x[i], y[i]);
		}
		dataset.addSeries(series3);
		JFreeChart chart = ChartFactory.createXYLineChart("line", "x", "y", dataset,
				PlotOrientation.VERTICAL, true, true, true);
		XYPlot plot = chart.getXYPlot();
		XYLineAndShapeRenderer renderer = (XYLineAndShapeRenderer) plot.getRenderer();
		renderer.setBaseShapesVisible(true);
		renderer.setSeriesLinesVisible(0, false);   // positive samples: markers only
		renderer.setSeriesLinesVisible(1, false);   // negative samples: markers only
		renderer.setSeriesFillPaint(0, Color.BLUE);
		renderer.setSeriesFillPaint(1, Color.RED);
		renderer.setUseFillPaint(true);
		renderer.setSeriesPaint(2, Color.BLACK);    // fitted boundary drawn in black
		frame1 = new ChartPanel(chart, true);
		JFrame jFrame = new JFrame("Line Chart");
		jFrame.add(frame1);
		jFrame.setBounds(50, 50, 600, 300);
		jFrame.setVisible(true);
	}
}

拟合直线如图所示:


用一个样本更新参数θ即为随机梯度上升算法

θj = θj + α*x_ij*(yi-hθ(xi))   （仅用第 i 个样本）

上面部分代码替换为

/**
 * Stochastic gradient ascent: β is updated in place after each individual
 * sample, 5000 passes over the data.
 * Fix: the original's beta_tmp was an ALIAS of beta (float[] beta_tmp=beta),
 * so the "temporary" never existed — updates were already in place, which is
 * exactly what SGD does per sample. The misleading dead variable and the
 * no-op "beta=beta_tmp" are removed; behavior is unchanged.
 *
 * @param x    design matrix (x[i][0] == 1)
 * @param y    0/1 labels
 * @param beta initial parameters; updated in place
 * @param a    learning rate α
 */
public void Logistic_main(float x[][], float y[], float beta[], float a) {
	for (int iter = 0; iter < 5000; iter++) {
		for (int k = 0; k < x.length; k++) {          // one update per sample k
			for (int j = 0; j < beta.length; j++) {
				beta[j] = beta[j] + a * Logistic_D(x, y, beta, j, k);
			}
		}
	}
	this.beta = beta;
}
//求解极大似然函数的偏导:k为第k个样本
	/**
	 * Single-sample partial derivative of the log-likelihood with respect to
	 * β_j, using only sample k: x_kj * (y_k - hθ(x_k)).
	 * Fix: sum_exp was evaluated twice per call; it is now computed once.
	 * NOTE(review): sum_exp overflows for large β·x_k, giving Inf/(1+Inf)=NaN —
	 * a numerically stable sigmoid would be preferable; verify against data range.
	 */
	public float Logistic_D(float x[][], float y[], float beta[], int j, int k) {
		float e = sum_exp(x[k], beta);   // e^(β·x_k), computed once
		float h = e / (1 + e);           // hθ(x_k)
		return (y[k] - h) * x[k][j];
	}

拟合直线如图所示


随着迭代次数的增加,参数θ趋于稳定


参考:1.https://blog.csdn.net/pat_datamine/article/details/43272555

         2.《机器学习实战》

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值