Logistic Regression算法分析及Java代码
只是从以下几个方面学习,现在还只是学习阶段,转载分析别人的文章
- Cost 函数原理,及似然函数相关推导
- Sigmoid 函数的原理(为什么用Sigmoid来做这个函数)
- 最速下降法原理 (原理分析及其他算法)
遍历一个矩阵:
public static void matrixTraverse(Matrix matrix) {
    // Print every entry of the matrix, tab-separated, one row per line.
    int rows = matrix.getRowDimension();
    int cols = matrix.getColumnDimension();
    for (int r = 0; r < rows; r++) {
        StringBuilder row = new StringBuilder();
        for (int c = 0; c < cols; c++) {
            row.append(matrix.get(r, c)).append("\t");
        }
        row.append("\n");
        System.out.print(row);
    }
}
Java读取文件中的数据:
/**
 * Reads the tab-separated data file and builds the feature matrix.
 * Each input line is expected to contain at least two numeric columns;
 * every output row is {1.0, col0, col1} — the leading 1.0 is the
 * intercept (bias) term for logistic regression.
 *
 * @param pathname path to the tab-separated data file
 * @return an n x 3 feature Matrix, or null if the file cannot be read
 */
public static Matrix getData(String pathname) {
    List<double[]> list = new ArrayList<double[]>();
    // try-with-resources guarantees the reader (and underlying stream)
    // is closed even on error — the original version leaked them.
    try (BufferedReader br = new BufferedReader(
            new InputStreamReader(new FileInputStream(new File(pathname))))) {
        String line;
        while ((line = br.readLine()) != null) {
            String[] tmp = line.split("\t");
            double[] value = new double[3];
            value[0] = 1.0; // intercept term
            value[1] = Double.parseDouble(tmp[0]);
            value[2] = Double.parseDouble(tmp[1]);
            list.add(value);
        }
        // Convert the accumulated rows directly instead of a manual
        // iterator-driven copy loop.
        double[][] data = list.toArray(new double[list.size()][]);
        return new Matrix(data);
    } catch (IOException e) {
        // FileNotFoundException is an IOException, so one catch covers
        // both original handlers with identical behavior.
        e.printStackTrace();
        return null;
    }
}
读取数据文件中的Label:
/**
 * Reads the tab-separated data file and builds the label (class) vector.
 * The label is taken from the third column of each line.
 *
 * @param pathname path to the tab-separated data file
 * @return an n x 1 label Matrix, or null if the file cannot be read
 */
public static Matrix getLabel(String pathname) {
    List<double[]> list = new ArrayList<double[]>();
    // try-with-resources closes the reader chain on every exit path;
    // the original left the streams open. The redundant extra brace
    // block around the body has also been removed.
    try (BufferedReader br = new BufferedReader(
            new InputStreamReader(new FileInputStream(new File(pathname))))) {
        String line;
        while ((line = br.readLine()) != null) {
            String[] tmp = line.split("\t");
            // third column holds the class label (0 or 1)
            list.add(new double[] { Double.parseDouble(tmp[2]) });
        }
        double[][] label = list.toArray(new double[list.size()][]);
        return new Matrix(label);
    } catch (IOException e) {
        e.printStackTrace();
        return null;
    }
}
算法主要部分,梯度下降法:
/**
 * Batch gradient ascent for logistic regression.
 * Iteratively updates the weight vector with
 *   w = w + alpha * X^T * (y - sigmoid(X * w))
 * which maximizes the log-likelihood of the training labels.
 *
 * @return the learned 3x1 weight Matrix
 */
public static Matrix gradient() {
    // NOTE(review): hard-coded data path — consider making it a parameter.
    String pathname = "E:\\Study\\Python_R\\Python_Books\\Machine Learning in Action\\machinelearninginaction\\Ch05\\testSet.txt";
    Matrix matrixLabel = getLabel(pathname);
    Matrix matrixData = getData(pathname);
    // Initial weights: all ones (3x1 — intercept plus two features).
    double[][] weight = new double[3][1];
    weight[0][0] = 1;
    weight[1][0] = 1;
    weight[2][0] = 1;
    Matrix weightMat = new Matrix(weight);
    double alpha = 0.001;   // learning rate
    int maxCycles = 500;    // number of gradient steps
    // BUGFIX: loop previously started at i = 1, performing only
    // maxCycles - 1 = 499 iterations instead of the intended 500.
    for (int i = 0; i < maxCycles; i++) {
        Matrix mm = matrixData.times(weightMat); // X * w  (n x 1 scores)
        Matrix h = sigmoid(mm);                  // predicted probabilities
        Matrix e = matrixLabel.minus(h);         // error = y - h
        // Gradient-ascent step on the log-likelihood.
        weightMat = weightMat.plus(matrixData.transpose().times(e).times(alpha));
    }
    return weightMat;
}
博客今天先写到这里~something wrong happens…