package logistc;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
/**
 * Entry point: loads a training set and a test set from comma-separated UTF-8 text files,
 * trains a {@link Logistic} model on the training set and reports its predictions/accuracy
 * on the test set.
 */
public class Mian {
    /** Training-set file used when no path is given on the command line. */
    private static final String DEFAULT_TRAIN_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\datas.txt";
    /** Test-set file used when no path is given on the command line. */
    private static final String DEFAULT_TEST_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\test.txt";

    /**
     * Reads both data sets, trains the model and evaluates it.
     *
     * @param args optional {@code [trainPath, testPath]}; missing entries fall back to the
     *             default paths above
     * @throws IOException declared for source compatibility; read errors are reported on
     *                     stdout and end the program instead of propagating
     */
    public static void main(String[] args) throws IOException {
        String trainPath = (args != null && args.length >= 1) ? args[0] : DEFAULT_TRAIN_PATH;
        String testPath = (args != null && args.length >= 2) ? args[1] : DEFAULT_TEST_PATH;

        ArrayList<ArrayList<Double>> datas;
        ArrayList<ArrayList<Double>> test;
        try {
            // Training rows get a leading constant 1.0 (bias feature); test rows are kept
            // exactly as they appear in the file.
            datas = readRows(trainPath, true);
            test = readRows(testPath, false);
        } catch (IOException ioe) {
            System.out.println("错误!" + ioe);
            return; // nothing sensible to train on; previously this fell through and crashed
        }
        if (datas.isEmpty()) {
            System.out.println("错误!训练集为空: " + trainPath);
            return;
        }

        Logistic l = new Logistic(datas, test);
        l.print();
        l.predect(test);
    }

    /**
     * Reads a UTF-8 text file of comma-separated numbers, one row per line. Blank lines are
     * skipped and a leading UTF-8 byte-order mark is ignored.
     *
     * @param path        file to read
     * @param prependBias if true, a constant 1.0 is inserted as the first element of every row
     * @return the parsed rows in file order
     * @throws IOException if the file cannot be read or a field is not a valid number (the
     *                     message names the file and 1-based line number)
     */
    private static ArrayList<ArrayList<Double>> readRows(String path, boolean prependBias)
            throws IOException {
        ArrayList<ArrayList<Double>> rows = new ArrayList<ArrayList<Double>>();
        // Declaring every stream as a resource guarantees they are all closed, even on error.
        try (FileInputStream fis = new FileInputStream(path);
                InputStreamReader isr = new InputStreamReader(fis, "UTF-8");
                BufferedReader br = new BufferedReader(isr)) {
            String line;
            int lineNo = 0;
            while ((line = br.readLine()) != null) {
                lineNo++;
                if (lineNo == 1 && line.startsWith("\uFEFF")) {
                    line = line.substring(1); // BOM written by some Windows editors
                }
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split(",");
                ArrayList<Double> row = new ArrayList<Double>(fields.length + 1);
                if (prependBias) {
                    row.add(1.0);
                }
                for (String field : fields) {
                    try {
                        row.add(Double.parseDouble(field));
                    } catch (NumberFormatException nfe) {
                        throw new IOException(
                                path + ":" + lineNo + ": not a number: \"" + field + "\"", nfe);
                    }
                }
                rows.add(row);
            }
        }
        return rows;
    }
}
package logistc;
import java.util.ArrayList;
/**
 * Binary logistic-regression classifier trained by batch gradient ascent on the
 * log-likelihood.
 *
 * <p>Expected row layouts:
 * <ul>
 *   <li>training rows: {@code [1.0, x1, ..., xk, y]} - a constant 1.0 bias feature, k feature
 *       values, then the 0/1 label;</li>
 *   <li>test rows passed to {@link #predect}: {@code [x1, ..., xk, y]} - no bias column.</li>
 * </ul>
 * The parameter vector has one entry per training column except the label, so entry 0 is the
 * intercept and entry j weights feature xj.
 */
public class Logistic {
    /** Training set (row layout described in the class comment). */
    private ArrayList<ArrayList<Double>> datas = new ArrayList<ArrayList<Double>>();
    /** Learning rate (step size) of gradient ascent. */
    private double alph = 0.001;
    /** Parameter vector theta; b[0] is the intercept. */
    private double[] b;

    /**
     * Creates a classifier over the given training set and initialises theta to all ones.
     *
     * @param datas training rows {@code [1.0, x1..xk, y]}; must be non-null and non-empty
     * @param test  not used here; kept for source compatibility (evaluate via {@link #predect})
     * @throws IllegalArgumentException if {@code datas} is null/empty or its rows are too short
     */
    public Logistic(ArrayList<ArrayList<Double>> datas, ArrayList<ArrayList<Double>> test) {
        if (datas == null || datas.isEmpty()) {
            throw new IllegalArgumentException("training set must not be empty");
        }
        this.datas = datas;
        init(datas);
    }

    /**
     * (Re)initialises theta to all ones and prints its length. Theta is sized from the first
     * row of this object's own training set (one weight per column except the label).
     *
     * @param datas ignored; kept for source compatibility so theta always matches the rows
     *              that training actually uses
     * @throws IllegalArgumentException if training rows have fewer than two columns
     */
    public void init(ArrayList<ArrayList<Double>> datas) {
        int width = this.datas.get(0).size() - 1;
        if (width < 1) {
            throw new IllegalArgumentException(
                    "training rows need a bias column and a label, got "
                            + this.datas.get(0).size() + " column(s)");
        }
        b = new double[width];
        System.out.println(b.length);
        for (int i = 0; i < b.length; i++) {
            b[i] = 1.0;
        }
    }

    /**
     * Returns the model's probability that training row {@code j} belongs to class 1,
     * i.e. sigmoid(theta . x) over every non-label column (including the 1.0 bias column,
     * so the learned intercept b[0] takes part).
     *
     * @param j index of a training row
     * @return a value in (0, 1)
     */
    public double h_theta_x_i(int j) {
        ArrayList<Double> row = this.datas.get(j);
        double z = 0.0;
        for (int i = 0; i < b.length; i++) {
            z += b[i] * row.get(i);
        }
        return sigmoid(z);
    }

    /**
     * Returns the partial derivative of the log-likelihood with respect to theta_j under the
     * current parameters: sum over rows of (y_i - h(x_i)) * x_ij.
     *
     * @param j parameter index (0 = intercept)
     * @return the gradient component
     */
    public double compute_partial_derivative_for_theta(int j) {
        double sum = 0.0;
        for (int i = 0; i < this.datas.size(); i++) {
            ArrayList<Double> row = datas.get(i);
            sum += (label(row) - h_theta_x_i(i)) * row.get(j);
        }
        return sum;
    }

    /**
     * Performs one batch gradient-ascent step on every parameter, including the intercept.
     * All gradient components are computed from the same (pre-step) theta before any of them
     * is applied, and each prediction h(x_i) is evaluated only once per step.
     */
    public void compute_theta() {
        int n = datas.size();
        double[] residual = new double[n];
        for (int i = 0; i < n; i++) {
            residual[i] = label(datas.get(i)) - h_theta_x_i(i);
        }
        double[] grad = new double[b.length];
        for (int i = 0; i < n; i++) {
            ArrayList<Double> row = datas.get(i);
            for (int j = 0; j < b.length; j++) {
                grad[j] += residual[i] * row.get(j);
            }
        }
        for (int j = 0; j < b.length; j++) {
            b[j] += this.alph * grad[j];
        }
    }

    /**
     * Trains for 1,000,000 gradient-ascent steps, printing the remaining-step counter and
     * theta after every step.
     */
    public void print() {
        print(1000000, 1);
    }

    /**
     * Trains for {@code iterations} gradient-ascent steps. After a step whose remaining-step
     * counter is a multiple of {@code logEvery}, prints {@code <counter>theta:} followed by
     * each theta component and a tab. The counter runs from {@code iterations - 1} down to 0,
     * so the final step is always logged when {@code logEvery > 0}.
     *
     * @param iterations number of training steps
     * @param logEvery   log interval in steps; values {@code <= 0} disable logging
     */
    public void print(int iterations, int logEvery) {
        for (int a = iterations - 1; a >= 0; a--) {
            compute_theta();
            if (logEvery > 0 && a % logEvery == 0) {
                StringBuilder line = new StringBuilder();
                line.append(a).append("theta:");
                for (int i = 0; i < b.length; i++) {
                    line.append(b[i]).append('\t');
                }
                System.out.println(line);
            }
        }
    }

    /**
     * Classifies every test row, printing 1 or 0 per row (no separator), then prints the
     * accuracy as a percentage.
     *
     * @param test rows {@code [x1..xk, y]} (no bias column) with the same k as training
     * @throws IllegalArgumentException if a row's feature count does not match the model
     */
    public void predect(ArrayList<ArrayList<Double>> test) {
        if (test == null || test.isEmpty()) {
            System.out.println("测试集为空!");
            return; // avoid printing 0/0 = NaN accuracy
        }
        int count = 0;
        for (ArrayList<Double> row : test) {
            int features = row.size() - 1;
            if (features != b.length - 1) {
                throw new IllegalArgumentException(
                        "test row has " + features + " feature(s), model expects " + (b.length - 1));
            }
            // Start from the intercept (test rows carry no 1.0 bias column); the sum is
            // per-row, not accumulated across rows.
            double z = b[0];
            for (int j = 0; j < features; j++) {
                z += b[j + 1] * row.get(j);
            }
            double actual = label(row);
            if (sigmoid(z) > 0.5) {
                System.out.print(1);
                if (actual == 1.0) {
                    count++;
                }
            } else {
                System.out.print(0);
                if (actual == 0.0) {
                    count++;
                }
            }
        }
        System.out.println("正确率为:" + (double) count / test.size() * 100 + "%");
    }

    /** Logistic function 1 / (1 + e^-z). */
    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Returns the label stored in the last column of {@code row}. */
    private static double label(ArrayList<Double> row) {
        return row.get(row.size() - 1);
    }
}