###### Logistic Regression Java程序

package logistic;

public class Instance {
public int label;
public double[] x;
public Instance(){}
public Instance(int label,double[] x){
this.label = label;
this.x = x;
}
}


package logistic;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// NOTE(review): this is a fragment of a data-loading method whose signature was
// lost in extraction — presumably something like
// `public List<Instance> readDataSet(String fileName, boolean isEnd)`.
List<Instance> data = new ArrayList<Instance>();
// NOTE(review): the body of this try block (opening the file/reader) was lost
// in extraction; as written this code does not compile. TODO restore.
try {
} catch (FileNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
String row = new String();

// NOTE(review): the read loop around this block (iterating rows and adding
// `new Instance(label, x)` to `data`) also appears truncated.
try {
String[] A = row.split(",");
double[] x = getDoubleArray(A,isEnd);
int label = 0;
// label position depends on the CSV layout: last column when isEnd, else first
if(isEnd)
label = Integer.parseInt(A[A.length-1]);
else
label = Integer.parseInt(A[0]);
}
} catch (IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
return data;
}
/**
 * Converts one split CSV row into a feature vector with a leading bias term.
 * The label column (first or last, per {@code isEnd}) is skipped, so the
 * result has A.length entries: bias plus A.length-1 features.
 *
 * @param A     the tokens of one CSV row (features plus one label column)
 * @param isEnd true if the label is the last column, false if it is the first
 * @return feature vector X where X[0] == 1.0 (bias) and the rest are parsed features
 */
public double[] getDoubleArray(String[] A, boolean isEnd){
    double[] X = new double[A.length];
    X[0] = 1.0; // bias input: every feature vector starts with a constant 1
    int first = isEnd ? 0 : 1;                   // label last → features start at column 0
    int last = isEnd ? A.length - 1 : A.length;  // label first → features start at column 1
    int out = 1;
    for(int k = first; k < last; k++){
        X[out++] = Double.parseDouble(A[k]);
    }
    return X;
}
}


Logistic.java文件详细分析

    // learning rate
private double rate;
// w — weight vector (index 0 is the bias weight)
private double[] weights;
// w from the previous iteration
private double [] upWeights;
// deltaW — weight update computed by the gradient step
private double[] deltaWeights;
// errSum — aggregate training residual of the current iteration
private double errSum;
// dimensionality of the input vector (excluding the bias)
private int size;
// maximum number of iterations
private int limit = 500 ;
// stopping threshold on |errSum|
private double epsilon = 0.0001;
// training set
List<Instance> data;
// RNG with a fixed seed, for reproducible weight init and SGD sampling
Random random = new Random(201606);

// 不同的构造器
public Logistic(int size){
this.rate=0.001;
this.size = size;
weights = new double[size+1];
upWeights = new double[size+1];
deltaWeights = new double[size+1];
errSum = 0.0;
//      upWeights = weights;
randomWeights(weights);

}
/**
 * @param rate learning rate (overrides the 0.001 default)
 */
public Logistic(int size,double rate){
this(size);
this.rate = rate;
}
/**
 * @param limit maximum number of iterations (overrides the default of 500)
 */
public Logistic(int size,double rate,int limit){
this(size,rate);
this.limit = limit;
}
/**
 * @param epsilon stopping threshold on |errSum| (overrides the 0.0001 default)
 */
public Logistic(int size,double rate,int limit,double epsilon){
this(size,rate,limit);
this.epsilon = epsilon;
}
}


/**
 * Fills {@code mat} in place with random initial weights from the shared RNG.
 * Each draw is in (0,1); draws at or below 0.5 are negated, so results lie in
 * (-0.5, 0] ∪ (0.5, 1). NOTE(review): this range is asymmetric — confirm intended.
 *
 * @param mat weight array to initialise
 */
public void randomWeights(double[] mat){
    for(int i = 0; i < mat.length; i++){
        double draw = random.nextDouble();
        mat[i] = (draw > 0.5) ? draw : -draw;
    }
}

    /**
 * Trains the model on the given data file.
 * NOTE(review): both branch bodies are missing — presumably the batch-GD
 * loop vs. the SGD loop shown later in this file; as written this does not
 * compile. Loading of fileName/isEnd also appears to have been lost in
 * extraction. TODO restore.
 * @param fileName path to the CSV training file
 * @param isEnd true if the label is the last CSV column, false if the first
 * @param isGD true → batch gradient descent, false → stochastic gradient descent
 */
public void train(String fileName,boolean isEnd,boolean isGD){
normalization();
if(isGD)
else
}

    /**
* 归一化
*/
public void normalization(){
double[] min = data.get(0).x;
double[] max = data.get(1).x;
for(int i=1;i<data.size();i++){
Instance instance = data.get(i);
double[] x = instance.x;
for(int j=0;j<x.length;j++){
if(x[j]<min[j]){
min[j] = x[j];
}else if(x[j]>max[j]){
max[j] = x[j];
}
}
}

for(int i=0;i<data.size();i++){
Instance instance = data.get(i);
int label = instance.label;
double[] x = instance.x;
for(int j=0;j<x.length;j++){
if(max[j] == min[j])
continue;
x[j] = (x[j] - min[j])/(max[j] - min[j]);
}
data.set(i, new Instance(label,x));
}
}

/**
 * Predicts the class of a feature vector.
 *
 * @param x feature vector (x[0] must be the bias input)
 * @return 1 if sigmoid(w·x) exceeds 0.5, otherwise 0
 */
public int predict(double[] x){
    return sigmoid(wTx(x)) > 0.5 ? 1 : 0;
}

    /**
 * Computes the fraction of instances in {@code testData} whose predicted
 * label matches the true label.
 *
 * Fix vs. original: an empty test set returned NaN (0.0/0); now returns 0.0.
 *
 * @param testData instances to evaluate
 * @return accuracy in [0, 1]
 */
public double accuracyRate(List<Instance> testData){
    if(testData == null || testData.isEmpty()){
        return 0.0; // no data — report zero accuracy instead of NaN
    }
    int correct = 0;
    for(Instance instance : testData){
        if(instance.label == predict(instance.x)){
            correct++;
        }
    }
    return (double) correct / testData.size();
}

    /**
 * Returns a copy of the learned weight vector (index 0 is the bias weight).
 *
 * Fix vs. original: the internal array was returned directly, letting callers
 * mutate the model's state; a defensive copy keeps the model encapsulated.
 *
 * @return a fresh copy of w
 */
public double[] getWeights(){
    return weights.clone();
}

$w$ 的更新规则

sigmoid函数：

    /**
 * Logistic sigmoid: σ(z) = 1 / (1 + e^(-z)).
 *
 * Fix vs. original: uses Math.exp(-z) instead of Math.pow(Math.E, -z) — the
 * dedicated exp is the idiomatic, faster and more accurate form. It saturates
 * cleanly to 0/1 for large |z|, so the previously commented-out clamping at
 * ±10 is unnecessary.
 *
 * @param z the linear score w·x
 * @return σ(z) in (0, 1)
 */
public double sigmoid(double z){
    return 1.0 / (1.0 + Math.exp(-z));
}

    /**
 * Computes the inner product w·x of the weight vector and a feature vector.
 * Assumes x.length == weights.length (size + 1, bias included) — TODO confirm
 * all callers pass bias-augmented vectors.
 *
 * @param x feature vector
 * @return the dot product wᵀx
 */
public double wTx(double[] x){
    double dot = 0.0;
    for(int idx = 0; idx < weights.length; idx++){
        dot += weights[idx] * x[idx];
    }
    return dot;
}

    /**
 * Batch variant: accumulates the residual (label - sigmoid(w·x)) over the
 * entire training set into errSum. The result feeds calculateDeltaWeights()
 * and the judge() stopping criterion.
 */
public void calculateErrSum(){
    double total = 0.0;
    for(Instance instance : data){
        double predicted = sigmoid(wTx(instance.x));
        total += instance.label - predicted;
    }
    errSum = total;
}
/**
 * Stochastic variant: errSum becomes the residual (label - sigmoid(w·x)) of
 * the single training instance at index {@code id}.
 *
 * @param id index of the sampled instance in the training set
 */
public void calculateErrSum(int id){
    Instance sample = data.get(id);
    double predicted = sigmoid(wTx(sample.x));
    errSum = sample.label - predicted;
}

    /**
 * Batch gradient step: deltaW[i] = (errSum / m) · Σ_j x_j[i].
 * NOTE(review): the standard LR gradient is Σ_j (err_j · x_j[i]); here the
 * summed residual multiplies the summed feature instead. Kept as written,
 * since the article's accompanying text documents exactly this formula.
 */
public void calculateDeltaWeights(){
    int m = data.size();
    for(int i = 0; i < deltaWeights.length; i++){
        double featureSum = 0.0;
        for(Instance instance : data){
            featureSum += instance.x[i];
        }
        deltaWeights[i] = featureSum * errSum / m;
    }
}
/**
 * "Stochastic" gradient step: deltaW[i] = errSum · Σ_j x_j[i], where errSum
 * holds the residual of the single sampled instance (see calculateErrSum(int)).
 * NOTE(review): the feature sum still runs over the WHOLE training set; true
 * SGD would use only the sampled instance's features. Behaviour kept as
 * published.
 *
 * Fix vs. original: removed the unused local `m` (data.size() was stored but
 * never used in this variant).
 */
public void stocCalculateDeltaWeights(){
    for(int i = 0; i < deltaWeights.length; i++){
        double featureSum = 0.0;
        for(int j = 0; j < data.size(); j++){
            featureSum += data.get(j).x[i];
        }
        deltaWeights[i] = featureSum * errSum;
    }
}

$\Delta w$ 就是负梯度的方向。在上面的计算公式中，每一个梯度分量的计算都需要 $\sum_j x_j[i]$，所以 $errSum$ 在一定程度上体现了梯度的性质（把多维数据用一维数据表示输出，便于肉眼观察）。从下面的程序输出可以发现 $errSum$ 在不断减小，同时停止迭代的条件也是在 $errSum$ 很小的时候停止。

    /**
 * Applies the gradient step: w ← w + rate · deltaW, saving the previous w
 * into upWeights first.
 *
 * Fix vs. original: `upWeights = weights` merely aliased the two arrays, so
 * the subsequent in-place update destroyed the saved previous-iteration
 * values. Cloning preserves an independent snapshot.
 */
public void updateWeights(){
    upWeights = weights.clone();
    for(int i = 0; i < weights.length; i++){
        weights[i] += rate * deltaWeights[i];
    }
}

    /**
 * Stopping criterion for the training loops.
 *
 * @return true once |errSum| has dropped below epsilon
 */
public boolean judge(){
    double residual = Math.abs(errSum);
    return residual < epsilon;
}

1. 计算训练误差
2. 计算 $\Delta w$
3. 更新 $w$

    /**
 * Batch gradient descent loop: each iteration computes the aggregate residual
 * over the whole training set, derives deltaW from it, and updates w.
 * Stops early once judge() reports |errSum| < epsilon, else runs `limit`
 * rounds. NOTE(review): the enclosing method's signature was lost in
 * extraction — presumably something like `public void gradientDescent()`.
 */
for(int i=0;i<limit;i++){
calculateErrSum();
System.out.println("errSum:"+errSum);
calculateDeltaWeights();
updateWeights();
// progress markers at 90% and 100% of the iteration budget
if(i==limit*0.9){
System.out.println("到达最大迭代次数的90%");
}
if(i==limit-1){
System.out.println("到达最大迭代次数");
}
//          printArray(weights);
//          printArray(upWeights);
if(judge()){
break;
}
}
}
/**
 * Stochastic gradient descent loop: each iteration samples one random
 * instance, computes its residual and the resulting deltaW, and applies the
 * update. Stops early once judge() reports |errSum| < epsilon.
 * NOTE(review): the enclosing method's signature was lost in extraction.
 *
 * Fix vs. original: updateWeights() was never called after
 * stocCalculateDeltaWeights() — the batch loop calls it, but here the
 * computed delta was discarded, so SGD never changed the weights at all.
 */
int nextId = random.nextInt(data.size());
for(int i = 0; i < limit; i++){
    nextId = random.nextInt(data.size()); // sample one instance per iteration
    calculateErrSum(nextId);
    System.out.println("errSum:"+errSum);
    stocCalculateDeltaWeights();
    updateWeights(); // apply the step (missing in the original)
    // progress markers at 90% and 100% of the iteration budget
    if(i==limit*0.9){
        System.out.println("到达最大迭代次数的90%");
    }
    if(i==limit-1){
        System.out.println("到达最大迭代次数");
    }
    if(judge()){
        break;
    }
}
}

Test类

package logistic;

import java.util.List;

public class Test {

    /**
     * Trains a logistic-regression model on logistic1.csv (label in the last
     * column, 2 features) and reports training-set accuracy and the learned
     * weights.
     */
    public static void main(String[] args) {
        String fileName = "logistic1.csv";
        boolean isEnd = true;   // label is the last CSV column
        int size = 2;           // number of input features (bias added internally)
        double rate = 0.001;
        int limit = 9000;
        boolean isGD = true;    // batch gradient descent
        Logistic logistic = new Logistic(size, rate, limit);
        logistic.train(fileName, isEnd, isGD);
        // Fix vs. original: `data` was an undeclared local here. Test lives in
        // the same `logistic` package, so the model's package-visible training
        // set is accessible — evaluate on that (self-test, as the article says).
        double acc = logistic.accuracyRate(logistic.data);
        System.out.println("正确率：" + acc);
        double[] w = logistic.getWeights();
        for (int i = 0; i < w.length; i++) {
            System.out.println(w[i] + "\t");
        }
    }

}


logistic1 是《机器学习实战》第五章逻辑斯蒂回归中的数据。用训练集自测（自己测试自己），准确率 90%。
学到的 $w$ 的值：

0.56644,-0.046765,-0.06700

《机器学习实战》中的结果
$w$ 的值

4.12,   0.48,  -0.6168

weka包中的LR回归预测
$w$ 的值

-14.7521,-1.2536,2.0027

#### logistic回归算法Java实现

2016年04月05日 4KB 下载

#### Logistic代码实战

2018-03-24 21:50:27

#### 机器学习算法之Logistic Regression算法实现

2016-10-07 23:10:15

#### 史上最直白的logistic regression教程 之 一

2015-11-17 15:11:38

#### 【JAVA实现】用Logistic回归进行分类

2015-03-14 22:54:16

#### logistic映射代码（MATLAB）

2009年12月28日 381B 下载

#### 机器学习之logistic回归算法的java实现

2017-11-09 19:46:49

#### 机器学习实战逻辑回归的java实现

2015-10-24 21:49:05

#### Java集成Weka做逻辑回归（Logistic Regression）

2016-07-13 00:07:03

#### 逻辑回归的相关问题及java实现

2014-06-30 23:35:17

## 不良信息举报

Logistic Regression Java程序