逻辑回归算法-CSDN

Logistic Regression算法分析及Java代码

只是从以下几个方面学习,现在还只是学习阶段,转载分析别人的文章
  1. Cost 函数原理,及似然函数相关推导
  2. Sigmoid 函数的原理(为什么用Sigmoid来做这个函数)
  3. 最速下降法原理 (原理分析及其他算法)

遍历一个矩阵:

public static void matrixTraverse(Matrix matrix) {
    //遍历这个矩阵;
        for (int i = 0; i < matrix.getRowDimension(); i++) {
            for (int j = 0; j < matrix.getColumnDimension(); j++) {
                System.out.print(matrix.get(i, j) + "\t");
            }
            System.out.print("\n");
        }
    }

Java读取文件中的数据:

public static Matrix getData(String pathname) {


//        String pathname="E:\\Study\\Python_R\\Python_Books\\Machine Learning in Action\\machinelearninginaction\\Ch05\\testSet.txt";
//        String pathname2="E:\\Documents\\data.txt";
        //把文件读进来
        String line = " ";
        List<double[]> list = new ArrayList();
//        Matrix matrix=new Matrix();
        try {
            InputStream in = new FileInputStream(new File(pathname));
            InputStreamReader inreader = new InputStreamReader(in);
            BufferedReader br = new BufferedReader(inreader);
            while ((line = br.readLine()) != null) {
                String[] tmp = line.split("\t");
//                    System.out.println(tmp[-1]);
                double[] value = new double[3];
                value[0] = 1.0;
                value[1] = Double.parseDouble(tmp[0]);
                value[2] = Double.parseDouble(tmp[1]);
                list.add(value);                                       //把数据放在list中保存;
            }
//                System.out.println("list_size:" + list.size());
            Iterator<double[]> it = list.iterator();
    /*    while (it.hasNext())
        {
            double[] it_next=it.next();
            for(double i:it_next){
                System.out.print(i+"\t");

            }
            System.out.print("\n");
        }*/
            //放在二维数组中
            double[][] data = new double[list.size()][3];
            for (int i = 0; it.hasNext(); i++) {
                double[] tmp = it.next();
                data[i] = tmp;
            }
            //遍历这个二维数组;
      /*  for(int i=0;i<data.length;i++){
            double[] arr=data[i];
            for(int j=0;j<arr.length;j++){
                System.out.print(data[i][j]+"\t");
            }
            System.out.print("\n");
        }*/
            Matrix matrix = new Matrix(data);

            return matrix;

        } catch (FileNotFoundException e) {
            e.printStackTrace();
            return null;
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }

    }

读取数据文件中的Label:

 public static Matrix getLabel(String pathname) {
    {
//            String pathname="E:\\Study\\Python_R\\Python_Books\\Machine Learning in Action\\machinelearninginaction\\Ch05\\testSet.txt";
//        Matrix matrix=getData(pathname);
//        matrixTraverse(matrix);
        String line = " ";
        List<double[]> list = new ArrayList<double[]>();
        try {
            FileInputStream fi = new FileInputStream(new File(pathname));
            InputStreamReader in = new InputStreamReader(fi);
            BufferedReader br = new BufferedReader(in);
//                System.out.println(br.readLine().getClass());

            //把数据放进list
            while ((line = br.readLine()) != null) {//这具br.readLine原来是这样的节奏;
                String[] tmp = line.split("\t");
                double[] value = new double[1];
                value[0] = Double.parseDouble(tmp[2]);
                list.add(value);
//                line = br.readLine();         //相当于i++,多了;这个错误犯的;

            }
//                System.out.println("list_size:\t" + list.size());
            //把数据放进double[][]
            Iterator<double[]> it = list.iterator();
            double[][] label = new double[list.size()][1];
            for (int i = 0; it.hasNext(); i++) {
                double[] tmp = it.next();
                label[i] = tmp;
            }
            Matrix labelMatrix = new Matrix(label);
            return labelMatrix;
        } catch (IOException e) {
            e.printStackTrace();

            return null;
        }

    }
}    

算法主要部分,梯度下降法:

public static Matrix gradient() {
    String pathname = "E:\\Study\\Python_R\\Python_Books\\Machine Learning in Action\\machinelearninginaction\\Ch05\\testSet.txt";
    Matrix matrixLabel = getLabel(pathname);
    Matrix matrixData = getData(pathname);
//        matrixTraverse(matrixData);
//        matrixTraverse(matrixLabel);

//写最速下降法来尝试下;
    //初始的weight
    double[][] weight = new double[3][1];
    weight[0][0] = 1;
    weight[1][0] = 1;
    weight[2][0] = 1;
    Matrix weightMat = new Matrix(weight);

    Matrix mm;
    Matrix h;
    Matrix e;

    double alpha = 0.001;
    int maxCycles = 500;

//        System.out.println("matrixData.times(weightMat):");
//        matrixTraverse(matrixData.times(weightMat).times(-1));

    for (int i = 1; i < maxCycles; i++) {
        mm = matrixData.times(weightMat);
        h = sigmoid(mm);
        e = matrixLabel.minus(h);
        weightMat = weightMat.plus(matrixData.transpose().times(e).times(alpha));//其实这个我没有太看懂;

    /*    System.out.println("----------------"+i+"--------------------");
        for(int j=0; j<weightMat.getRowDimension();j++){
            System.out.println(weightMat.get(j,0));
        }*/
    }
    return weightMat;
}

博客今天先写到这里~something worng happens…

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值