理论介绍
样本类 Instance 的定义（标签 + 特征向量）
package logistic;
/**
 * One labelled training sample: an integer class label together with its
 * feature vector.
 */
public class Instance {
    /** Class label of this sample. */
    public int label;

    /** Feature vector of this sample. */
    public double[] x;

    /** Creates an empty sample: {@code label} is 0 and {@code x} is null until assigned. */
    public Instance() {
    }

    /**
     * Creates a sample with the given label and features.
     *
     * @param label class label
     * @param x     feature vector; stored by reference, not copied
     */
    public Instance(int label, double[] x) {
        this.x = x;
        this.label = label;
    }
}
读取数据，并将每个样本特征向量的第一维设置为 1（即偏置项）
package logistic;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Loads a comma-separated data set into {@link Instance} objects.
 *
 * <p>Each non-blank line is one sample. The integer class label is either the
 * first or the last field (selected by {@code isEnd}); every other field is a
 * numeric feature. A constant 1.0 is placed at index 0 of every feature vector
 * as the bias term.
 */
public class ReadData {

    /**
     * Reads every non-blank line of {@code fileName} as one sample.
     *
     * @param fileName path of the comma-separated data file
     * @param isEnd    {@code true} if the label is the last field of each line,
     *                 {@code false} if it is the first field
     * @return the samples in file order; never null
     * @throws UncheckedIOException  if the file cannot be opened or read
     * @throws NumberFormatException if a label is not an integer or a feature is not a number
     */
    public List<Instance> readDataSet(String fileName, boolean isEnd) {
        List<Instance> data = new ArrayList<Instance>();
        // try-with-resources closes the file even if parsing throws midway.
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String row;
            while ((row = br.readLine()) != null) {
                // Blank lines carry no sample; skip them rather than failing in parseInt.
                if (row.trim().isEmpty()) {
                    continue;
                }
                String[] fields = row.split(",");
                double[] x = getDoubleArray(fields, isEnd);
                String labelField = isEnd ? fields[fields.length - 1] : fields[0];
                // trim() so the label tolerates surrounding spaces, as the features
                // already do via Double.parseDouble.
                int label = Integer.parseInt(labelField.trim());
                data.add(new Instance(label, x));
            }
        } catch (IOException e) {
            // Surface the failure (including a missing file) instead of printing it
            // and continuing with a null reader or partially-read data.
            throw new UncheckedIOException("Failed to read data set: " + fileName, e);
        }
        return data;
    }

    /**
     * Converts the fields of one line into a feature vector with a leading bias term.
     *
     * <p>The result has {@code A.length} entries: index 0 is always 1.0 (bias) and
     * indices 1.. hold the non-label fields in their original order.
     *
     * @param A     the comma-split fields of one line
     * @param isEnd {@code true} if the label is the last field, {@code false} if it is the first
     * @return the bias-augmented feature vector
     * @throws NumberFormatException if a feature field is not a number
     */
    public double[] getDoubleArray(String[] A, boolean isEnd) {
        double[] features = new double[A.length];
        features[0] = 1.0; // bias term: every sample's first dimension is 1
        // Skip the label column: the last one when isEnd, otherwise the first.
        int from = isEnd ? 0 : 1;
        int to = isEnd ? A.length - 1 : A.length;
        int j = 1;
        for (int k = from; k < to; k++) {
            features[j++] = Double.parseDouble(A[k]);
        }
        return features;
    }
}
Logistic.java文件详细分析
定义各个参数
// Learning rate.
private double rate;
// Weight vector w.
private double[] weights;
// Value of w from the previous iteration.
private double [] upWeights;
// deltaW: change in the weights (presumably per iteration — confirm against the training loop).
private double[] deltaWeights;
// errSum: accumulated training error.
private double errSum;
// Dimension of the input vector.
private int size;
// Maximum number of iterations.
private int limit = 500 ;
// Threshold for stopping iteration early.
private double epsilon = 0.0001;
// Training set.
List<Instance> data;
// Random number generator; the fixed seed makes every run produce the same sequence.
Random random = new Random(201606);
定义不同的构造器
// Constructors
/**
 * Creates a model for {@code size}-dimensional inputs with the default learning
 * rate of 0.001; the weights are then initialized by {@code randomWeights}.
 *
 * @param size dimension of the input vector; must be non-negative
 * @throws IllegalArgumentException if {@code size} is negative
 */
public Logistic(int size) {
    if (size < 0) {
        // Without this check size == -1 silently yields zero-length weight arrays,
        // and smaller values fail with NegativeArraySizeException.
        throw new IllegalArgumentException("size must be >= 0, got " + size);
    }
    this.rate = 0.001;
    this.size = size;
    // One slot beyond size — presumably for the bias weight paired with the constant
    // 1.0 that ReadData places at x[0]; confirm against the training code.
    weights = new double[size + 1];
    upWeights = new double[size + 1];
    deltaWeights = new double[size + 1];
    errSum = 0.0;
    randomWeights(weights);
}
/**
 * Creates a model with a caller-chosen learning rate; otherwise identical to
 * {@code Logistic(size)}.
 *
 * @param size dimension of the input vector
 * @param rate learning rate, replacing the default of 0.001
 */
public Logistic(int size,double rate){
this(size);
this.rate = rate;
}
/**
 * Creates a model with a caller-chosen learning rate and iteration cap.
 *
 * @param size  dimension of the input vector
 * @param rate  learning rate
 * @param limit maximum number of iterations, replacing the default of 500
 */
public Logistic(int size,double rate,int limit){
this(size,rate);
this.limit = limit;
}
/**
 * Creates a model with every training parameter chosen by the caller.
 *
 * @param size    dimension of the input vector
 * @param rate    learning rate
 * @param limit   maximum number of iterations
 * @param epsilon threshold for stopping iteration early, replacing the default of 0.0001
 */
public Logistic(int size,double rate,int limit,double epsilon){
this(size,rate,limit);
this.epsilon = epsilon;
}
/**
 * Replaces the training set with the samples read from {@code fileName}.
 *
 * @param fileName path of the comma-separated data file
 * @param isEnd    {@code true} if each line's label is its last field,
 *                 {@code false} if it is the first field
 */
public void loadData(String fileName, boolean isEnd) {
    this.data = new ReadData().readDataSet(fileName, isEnd);
}
随机初始化