关闭

Logistic Regression Java程序

标签: 逻辑斯蒂回归
924人阅读 评论(0) 收藏 举报
分类:

理论介绍
节点定义

package logistic;

public class Instance {
    public int label;
    public double[] x;
    public Instance(){}
    public Instance(int label,double[] x){
        this.label = label;
        this.x = x;
    }
}

读取数据,并在第一维数值设置为1

package logistic;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class ReadData {
    public List<Instance> readDataSet(String fileName,boolean isEnd){

        List<Instance> data = new ArrayList<Instance>();
        FileReader fileReader = null;
        try {
            fileReader = new FileReader(fileName);
        } catch (FileNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        BufferedReader br = new BufferedReader(fileReader);
        String row = new String();

        try {
            while((row=br.readLine())!=null){
                String[] A = row.split(",");
                double[] x = getDoubleArray(A,isEnd);
                int label = 0;
                if(isEnd)
                    label = Integer.parseInt(A[A.length-1]);
                else
                    label = Integer.parseInt(A[0]);
                data.add(new Instance(label,x));
            }
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return data;
    }
    public double[] getDoubleArray(String[] A,boolean isEnd){
        double[] X = new double[A.length];
        // label 在第一个位置
        int i = 1;
        int limit = A.length;
        // label 在最后一个位置
        if(isEnd){
            i=0;
            limit = A.length-1;
        }
        int j = 1;
        X[0] = 1.0;// 第一维是偏置,所有x的第一维是 1 
        for(int k=i;k<limit;k++){
            X[j++] = Double.parseDouble(A[k]);
        }
        return X;
    }
}

Logistic.java文件详细分析
定义各个参数

    //学习率
    private double rate;
    // w 权值向量
    private double[] weights;
    // w 上一次迭代值
    private double [] upWeights;
    // deltaW
    private double[] deltaWeights;
    // errSum 训练误差
    private double errSum;
    // 输入向量的维度
    private int size;
    // 迭代最大次数
    private int limit = 500 ;
    // 停止迭代阈值
    private double epsilon = 0.0001;
    // 训练集
    List<Instance> data;
    // 随机数
    Random random = new Random(201606);

定义不同的构造器

// 不同的构造器
    public Logistic(int size){
        this.rate=0.001;
        this.size = size;
        weights = new double[size+1];
        upWeights = new double[size+1];
        deltaWeights = new double[size+1];
        errSum = 0.0;
//      upWeights = weights;
        randomWeights(weights);

    }
    public Logistic(int size,double rate){
        this(size);
        this.rate = rate;
    }
    public Logistic(int size,double rate,int limit){
        this(size,rate);
        this.limit = limit;
    }
    public Logistic(int size,double rate,int limit,double epsilon){
        this(size,rate,limit);
        this.epsilon = epsilon;
    }
    public void loadData(String fileName,boolean isEnd){
        ReadData readData = new ReadData();
        data = readData.readDataSet(fileName, isEnd);
    }

随机初始化w的权值

/**
     * 随机初始化权值
     * @param mat
     */
    public void randomWeights(double[] mat){
        for(int i=0;i<mat.length;i++){
            double w = random.nextDouble();
            mat[i] = w>0.5?w:(-w);
        }
    }

训练模型:批量梯度下降法和随机梯度下降法
先对数据归一化,防止计算e的指数时候指数过大

    /**
     * 训练模型
     * @param fileName
     * @param isEnd
     */
    public void train(String fileName,boolean isEnd,boolean isGD){
        loadData(fileName,isEnd);
        normalization();
        if(isGD)
            gradientDescent();
        else
            stocGradientDescent();
    }

归一化

    /**
     * 归一化
     */
    public void normalization(){
        double[] min = data.get(0).x;
        double[] max = data.get(1).x;
        for(int i=1;i<data.size();i++){
            Instance instance = data.get(i);
            double[] x = instance.x;
            for(int j=0;j<x.length;j++){
                if(x[j]<min[j]){
                    min[j] = x[j];
                }else if(x[j]>max[j]){
                    max[j] = x[j];
                }
            }
        }

        for(int i=0;i<data.size();i++){
            Instance instance = data.get(i);
            int label = instance.label;
            double[] x = instance.x;
            for(int j=0;j<x.length;j++){
                if(max[j] == min[j])
                    continue;
                x[j] = (x[j] - min[j])/(max[j] - min[j]);
            }
            data.set(i, new Instance(label,x));
        }
    }

测试数据预测

/**
     * 预测 返回预测标签
     * @param x
     * @return
     */
    public int predict(double[] x){
        double wtx = wTx(x);
        double logit = sigmoid(wtx);
        if(logit>0.5)
            return 1;
        else
            return 0;
    }

预测精度

    /**
     * 返回预测精度
     * @param testData
     * @return
     */
    public double accuracyRate(List<Instance> testData){
        int count = 0;
        for(int i=0;i<testData.size();i++){
            Instance instance = testData.get(i);
            double[] x = instance.x;
            int label = instance.label;
            int pred=predict(x);
            if(label == pred)
                count++;
        }
        return count*1.0/testData.size();
    }

获取w

    /**
     * 返回权值
     * @return
     */
    public double[] getWeights(){
        return weights;
    }

w的更新规则
这里写图片描述
对于批量梯度下降法,需要计算所有数据集的训练误差,还需要将所有训练数据集的各个维度数据加起来
sigmoid函数:

    /**
     * sigmoid 函数
     * @param z
     * @return
     */
    public double sigmoid(double z){
//      if(z>10)
//          return 1.0;
//      if(z<-10)
//          return 0.0;
        return 1.0/(1+Math.pow(Math.E,-z));
    }

计算“预测值”:
这里说的预测值是权重于维度的乘积,就是计算上面的π(xi)的过程

    public double wTx(double[] x){
        double wtx = 0.0;
        for(int i=0;i<weights.length;i++){
            wtx += weights[i] * x[i];
        }
        return wtx;
    }

计算训练误差:
批量梯度下降法和随机梯度下降法

    /**
     * 批量梯度下降法 计算训练误差
     */
    public void calculateErrSum(){
        errSum = 0.0;
        for(int i=0;i<data.size();i++){
            Instance instance = data.get(i);
            double label = instance.label;
            double[] x = instance.x;
            double wtx = wTx(x);
            double logit = sigmoid(wtx);
            errSum += (label - logit);
        }

    }
    /**
     * 随机梯度下降法 计算训练误差
     */
    public void calculateErrSum(int id){
        errSum = 0.0;
        Instance instance = data.get(id);
        double real = instance.label;
        double[] x = instance.x;
        double wtx = wTx(x);
        double predict = sigmoid(wtx);
        errSum = (real - predict);
    }

计算上面贴图中的Δw

    /**
     * 批量梯度下降法更新 deltaW
     */
    public void calculateDeltaWeights(){
        int m = data.size();
        for(int i=0;i<deltaWeights.length;i++){
            deltaWeights[i] = 0.0;
            for(int j=0;j<data.size();j++){
                Instance instance = data.get(j);
                double[] x = instance.x;
                deltaWeights[i] += x[i];
            }
            deltaWeights[i] = deltaWeights[i] * errSum/m;
        }
    }
    /**
     * 随机梯度下降法
     */
    public void stocCalculateDeltaWeights(){
        int m = data.size();
        for(int i=0;i<deltaWeights.length;i++){
            deltaWeights[i] = 0.0;
            for(int j=0;j<data.size();j++){
                Instance instance = data.get(j);
                double[] x = instance.x;
                deltaWeights[i] += x[i];
            }
            deltaWeights[i] = deltaWeights[i] * errSum;
        }
    }

Δw就是负梯度的方向,在上面的计算公式中,sumxi在每个梯度计算的时候都需要的,所以errSum在一定程度上体现了梯度的性质《多维数据用一维数据表示输出方面,肉眼可观察》,下面程序输出可以发现errSum在减小,同时停止迭代的条件也是根据errSum很小的时候停止。
更新w

    /**
     * 更新weights
     */
    public void updateWeights(){
        upWeights = weights;
        for(int i=0;i<weights.length;i++){
            weights[i] = weights[i] + rate*deltaWeights[i];
        }
    }

停止迭代条件

    /**
     * errSum,很小的时候停止迭代
     * @return
     */
    public boolean judge(){
        return Math.abs(errSum)<epsilon;
    }

批量梯度与随机梯度下降法
每一次循环步骤:
1.计算训练误差
2.计算Δw
3.更新w

    /**
     * 批量梯度下降法
     */
    public void gradientDescent(){
        for(int i=0;i<limit;i++){
            calculateErrSum();
            System.out.println("errSum:"+errSum);
            calculateDeltaWeights();
            updateWeights();
            if(i==limit*0.9){
                System.out.println("到达最大迭代次数的90%");
            }
            if(i==limit-1){
                System.out.println("到达最大迭代次数");
            }
//          printArray(weights);
//          printArray(upWeights);
            if(judge()){
                break;
            }
        }
    }
    /**
     * 随机梯度下降法
     * 
     * @param x
     * @return
     */
    public void stocGradientDescent(){
        int nextId = random.nextInt(data.size());
        for(int i=0;i<limit;i++){
            nextId = random.nextInt(data.size());
            calculateErrSum(nextId);
            System.out.println("errSum:"+errSum);
            stocCalculateDeltaWeights();
            if(i==limit*0.9){
                System.out.println("到达最大迭代次数的90%");
            }
            if(i==limit-1){
                System.out.println("到达最大迭代次数");
            }
//          printArray(weights);
//          printArray(upWeights);
            if(judge()){
                break;
            }
        }
    }

Test类

package logistic;

import java.util.List;

public class Test {

    /**
     * @param args
     */
    public static void main(String[] args) {
        // TODO Auto-generated method stub
//      String fileName = "logisticaps.csv";
//      boolean isEnd = true;
//      int size = 12;
        String fileName = "logistic1.csv";
        boolean isEnd = true;
        int size = 2;
//      String fileName = "binary.csv"; logisticaps
//      boolean isEnd = false;
//      int size = 3;
        double rate = 0.001;
        int limit = 9000;
        boolean isGD = true;
        Logistic logistic = new Logistic(size,rate,limit);
        logistic.train(fileName, isEnd,isGD);
        ReadData readData = new ReadData();
        List<Instance> data = readData.readDataSet(fileName, isEnd);
        double acc = logistic.accuracyRate(data);
        System.out.println("正确率:"+acc);
        double w[] = logistic.getWeights();
        for(int i=0;i<w.length;i++){
            System.out.println(w[i]+"\t");
        }
    }

}

logistic1是《机器学习实战》第五章逻辑斯蒂回归中的数据进行测试准确率90%,自己测试自己
w的值

0.56644,-0.046765,-0.06700

这里写图片描述
预测错误10个数据点

《机器学习实战》中的结果
w的值

4.12,   0.48,  -0.6168

这里写图片描述
大约错了2-5个
weka包中的LR回归预测
w的值

-14.7521,-1.2536,2.0027

这里写图片描述
预测错了5个
结果还可以

对于用随机梯度下降法,预测效果就比较差了,程序中可能有问题。

用批量梯度下降法对其他数据集进行测试发现效果也很差
这个数据集,自己测试自己正确率70.5%,而用上面的程序全部预测成较多的一类

感觉上面程序有问题,但是不知道错在哪里?
用iris数据集测试,只用到前两类,weka包测试准确率100%,上面程序自测94%

感觉上面程序好像只对线性可分的时候预测效果比较好,线性不可分的时候就预测成多的一类,准备有时间加入正则项试试。

0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:183849次
    • 积分:4604
    • 等级:
    • 排名:第6793名
    • 原创:282篇
    • 转载:3篇
    • 译文:0篇
    • 评论:21条
    博客专栏