Logistic Regression Java程序

本文详述了使用Java编程实现逻辑斯蒂回归的过程,包括理论介绍、数据处理、模型训练(批量梯度下降与随机梯度下降)、sigmoid函数的运用、训练误差计算以及预测精度评估。通过不同数据集的测试,发现程序在处理线性可分数据时表现良好,而在线性不可分数据上可能预测效果不佳,计划加入正则项优化。
摘要由CSDN通过智能技术生成

理论介绍
节点定义

package logistic;

public class Instance {
    public int label;
    public double[] x;
    public Instance(){}
    public Instance(int label,double[] x){
        this.label = label;
        this.x = x;
    }
}

读取数据,并在第一维数值设置为1

package logistic;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class ReadData {
   
    public List<Instance> readDataSet(String fileName,boolean isEnd){

        List<Instance> data = new ArrayList<Instance>();
        FileReader fileReader = null;
        try {
            fileReader = new FileReader(fileName);
        } catch (FileNotFoundException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        BufferedReader br = new BufferedReader(fileReader);
        String row = new String();

        try {
            while((row=br.readLine())!=null){
                String[] A = row.split(",");
                double[] x = getDoubleArray(A,isEnd);
                int label = 0;
                if(isEnd)
                    label = Integer.parseInt(A[A.length-1]);
                else
                    label = Integer.parseInt(A[0]);
                data.add(new Instance(label,x));
            }
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return data;
    }
    public double[] getDoubleArray(String[] A,boolean isEnd){
        double[] X = new double[A.length];
        // label 在第一个位置
        int i = 1;
        int limit = A.length;
        // label 在最后一个位置
        if(isEnd){
            i=0;
            limit = A.length-1;
        }
        int j = 1;
        X[0] = 1.0;// 第一维是偏置,所有x的第一维是 1 
        for(int k=i;k<limit;k++){
            X[j++] = Double.parseDouble(A[k]);
        }
        return X;
    }
}

Logistic.java文件详细分析
定义各个参数

    //学习率
    private double rate;
    // w 权值向量
    private double[] weights;
    // w 上一次迭代值
    private double [] upWeights;
    // deltaW
    private double[] deltaWeights;
    // errSum 训练误差
    private double errSum;
    // 输入向量的维度
    private int size;
    // 迭代最大次数
    private int limit = 500 ;
    // 停止迭代阈值
    private double epsilon = 0.0001;
    // 训练集
    List<Instance> data;
    // 随机数
    Random random = new Random(201606);

定义不同的构造器

// 不同的构造器
    public Logistic(int size){
        this.rate=0.001;
        this.size = size;
        weights = new double[size+1];
        upWeights = new double[size+1];
        deltaWeights = new double[size+1];
        errSum = 0.0;
//      upWeights = weights;
        randomWeights(weights);

    }
    public Logistic(int size,double rate){
        this(size);
        this.rate = rate;
    }
    public Logistic(int size,double rate,int limit){
        this(size,rate);
        this.limit = limit;
    }
    public Logistic(int size,double rate,int limit,double epsilon){
        this(size,rate,limit);
        this.epsilon = epsilon;
    }
    public void loadData(String fileName,boolean isEnd){
        ReadData readData = new ReadData();
        data = readData.readDataSet(fileName, isEnd);
    }

随机初始化

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值