机器学习之logistic回归算法的java实现

package logistc;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

/**
 * Command-line entry point: loads a training set and a test set from comma-separated text
 * files, trains a {@link Logistic} model on the training set and prints its accuracy on the
 * test set.
 *
 * <p>Usage: {@code java logistc.Mian [trainFile [testFile]]}. Any path that is omitted falls
 * back to the original hard-coded location.
 */
public class Mian {

	/** Fallback training-set path (the original hard-coded location). */
	private static final String DEFAULT_TRAIN_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\datas.txt";
	/** Fallback test-set path (the original hard-coded location). */
	private static final String DEFAULT_TEST_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\test.txt";

	/**
	 * Reads both data sets, trains the model and evaluates it.
	 *
	 * @param args optional {@code [trainFile [testFile]]}
	 * @throws IOException kept for signature compatibility; read errors are reported and end the run
	 */
	public static void main(String[] args) throws IOException {
		String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
		String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;
		ArrayList<ArrayList<Double>> datas;
		ArrayList<ArrayList<Double>> test;
		try {
			// Training rows get a leading constant 1.0 so that the first parameter acts as
			// the intercept; test rows are stored exactly as read.
			datas = readRows(trainPath, true);
			test = readRows(testPath, false);
		} catch (IOException ioe) {
			System.out.println("错误!" + ioe);
			// Without data there is nothing to train on; stop here instead of crashing later.
			return;
		}

		Logistic l = new Logistic(datas, test);

		l.print();
		l.predect(test);
	}

	/**
	 * Parses a UTF-8 text file of comma-separated numbers, one sample per line.
	 * Blank lines are skipped and a leading byte-order mark is ignored.
	 *
	 * @param path        file to read
	 * @param prependBias if true, each row starts with an extra constant {@code 1.0}
	 * @return one list of values per non-blank line, in file order
	 * @throws IOException           if the file cannot be opened or read
	 * @throws NumberFormatException if a field is not a valid number
	 */
	private static ArrayList<ArrayList<Double>> readRows(String path, boolean prependBias) throws IOException {
		ArrayList<ArrayList<Double>> rows = new ArrayList<ArrayList<Double>>();
		// try-with-resources closes the whole reader chain even if parsing throws.
		try (BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8"))) {
			String line;
			while ((line = br.readLine()) != null) {
				// Editors such as Windows Notepad may start a UTF-8 file with a BOM,
				// which Double.parseDouble rejects.
				if (line.startsWith("\uFEFF")) {
					line = line.substring(1);
				}
				line = line.trim();
				if (line.isEmpty()) {
					continue; // tolerate blank or trailing empty lines
				}
				String[] fields = line.split(",");
				ArrayList<Double> row = new ArrayList<Double>(fields.length + 1);
				if (prependBias) {
					row.add(1.0);
				}
				for (String field : fields) {
					row.add(Double.parseDouble(field.trim()));
				}
				rows.add(row);
			}
		}
		return rows;
	}

}


package logistc;
import java.util.ArrayList;
/**
 * Binary logistic regression trained by batch gradient ascent on the log-likelihood.
 *
 * <p>Training rows must have the form {@code [1.0, x1, ..., xn, y]}: a constant 1.0 (bias
 * column), the n feature values, and the 0/1 label last. Test rows passed to
 * {@link #predect} have the form {@code [x1, ..., xn, y]}, without the bias column.
 */
public class Logistic {
	/** Training set; each row is [1.0, x1, ..., xn, y]. */
	private ArrayList<ArrayList<Double>> datas = new ArrayList<ArrayList<Double>>();
	/** Learning rate (step size) of each gradient-ascent update. */
	private double alph = 0.001;
	/** Number of gradient steps performed by {@link #print()}. */
	private static final int DEFAULT_ITERATIONS = 1000000;
	/** {@link #print()} reports the parameters once every this many iterations. */
	private static final int REPORT_EVERY = 100000;
	/** Parameter vector: b[0] is the intercept, b[k] is the weight of feature k. */
	private double[] b;

	/**
	 * Creates a model for the given training set and initializes its parameters.
	 *
	 * @param datas training rows, each {@code [1.0, x1, ..., xn, y]}
	 * @param test  test rows; accepted for API compatibility but not used here
	 * @throws IllegalArgumentException if the training set is empty or its rows are too short
	 */
	public Logistic(ArrayList<ArrayList<Double>> datas, ArrayList<ArrayList<Double>> test) {
		this.datas = datas;
		init(datas);
	}

	/**
	 * Sizes the parameter vector from the given rows (one entry per column except the label)
	 * and sets every parameter to 1.0. Prints the number of parameters.
	 *
	 * @param datas rows whose width defines the parameter count
	 * @throws IllegalArgumentException if {@code datas} is null/empty or rows have fewer than 2 columns
	 */
	public void init(ArrayList<ArrayList<Double>> datas) {
		if (datas == null || datas.isEmpty()) {
			throw new IllegalArgumentException("training set must contain at least one row");
		}
		int width = datas.get(0).size();
		if (width < 2) {
			throw new IllegalArgumentException("training rows need a bias column and a label, got width " + width);
		}
		b = new double[width - 1];
		System.out.println(b.length);
		for (int i = 0; i < b.length; i++) {
			b[i] = 1.0;
		}
	}

	/**
	 * Hypothesis for training sample {@code j}: sigmoid(b · x_j), where x_j includes the bias
	 * column, so b[0] is a learned intercept.
	 *
	 * @param j index of the training row
	 * @return predicted probability that the label of row {@code j} is 1
	 */
	public double h_theta_x_i(int j) {
		ArrayList<Double> row = datas.get(j);
		double z = 0.0;
		for (int k = 0; k < b.length; k++) {
			z += b[k] * row.get(k);
		}
		return sigmoid(z);
	}

	/**
	 * Partial derivative of the log-likelihood with respect to parameter {@code j}:
	 * sum over samples of (y_i - h(x_i)) * x_ij.
	 *
	 * @param j parameter index (0 is the intercept)
	 * @return the gradient component for b[j] at the current parameters
	 */
	public double compute_partial_derivative_for_theta(int j) {
		double sum = 0.0;
		for (int i = 0; i < datas.size(); i++) {
			ArrayList<Double> row = datas.get(i);
			sum += (label(row) - h_theta_x_i(i)) * row.get(j);
		}
		return sum;
	}

	/**
	 * Performs one batch gradient-ascent step on all parameters, including the intercept.
	 * The whole gradient is computed at the current parameters before any of them changes.
	 */
	public void compute_theta() {
		double[] gradient = new double[b.length];
		for (int i = 0; i < datas.size(); i++) {
			ArrayList<Double> row = datas.get(i);
			// Evaluate the hypothesis once per sample and reuse it for every component.
			double error = label(row) - h_theta_x_i(i);
			for (int k = 0; k < b.length; k++) {
				gradient[k] += error * row.get(k);
			}
		}
		for (int k = 0; k < b.length; k++) {
			b[k] += alph * gradient[k];
		}
	}

	/**
	 * Trains the model for a fixed number of iterations. Prints the remaining-iteration
	 * counter followed by the parameters every {@code REPORT_EVERY} iterations, ending with
	 * the final parameters (counter 0).
	 */
	public void print() {
		for (int remaining = DEFAULT_ITERATIONS - 1; remaining >= 0; remaining--) {
			compute_theta();
			// Printing on every iteration would write a million lines and dominate run time.
			if (remaining % REPORT_EVERY == 0) {
				System.out.print(remaining + "theta:");
				for (int i = 0; i < b.length; i++) {
					System.out.print(b[i] + "\t");
				}
				System.out.println();
			}
		}
	}

	/**
	 * Classifies each test row with the trained model. For each row it prints the predicted
	 * class (0/1) with no separator, then prints the accuracy as a percentage. An empty test
	 * set yields {@code NaN}.
	 *
	 * @param test rows of the form {@code [x1, ..., xn, y]} (no bias column)
	 * @throws IllegalArgumentException if a row's feature count does not match the model
	 */
	public void predect(ArrayList<ArrayList<Double>> test) {
		int count = 0;
		for (int i = 0; i < test.size(); i++) {
			ArrayList<Double> row = test.get(i);
			int featureCount = row.size() - 1;
			if (featureCount != b.length - 1) {
				throw new IllegalArgumentException("test row " + i + " has " + featureCount
						+ " features, model expects " + (b.length - 1));
			}
			// Start each sample from the intercept; the score must not carry over between rows.
			double z = b[0];
			for (int j = 0; j < featureCount; j++) {
				z += b[j + 1] * row.get(j);
			}
			int predicted = sigmoid(z) > 0.5 ? 1 : 0;
			System.out.print(predicted);
			if (label(row) == predicted) {
				count++;
			}
		}
		System.out.println("正确率为:" + (double) count / test.size() * 100 + "%");
	}

	/** Logistic function 1 / (1 + e^-z). */
	private static double sigmoid(double z) {
		return 1.0 / (1.0 + Math.exp(-z));
	}

	/** Label stored in the last column of a row. */
	private static double label(ArrayList<Double> row) {
		return row.get(row.size() - 1);
	}
}


评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值