package logistic;
import java.util.List;
/**
 * One data sample: a vector of numeric attributes plus a class label,
 * parsed from a single row of the horse-colic data file.
 */
public class Horse {
    // Attribute values from one data-file row (every column but the last).
    private List<Double> attributes;
    // Class label, kept as the raw string from the last column (e.g. "1.0").
    private String label;

    public List<Double> getAttributes() {
        return attributes;
    }

    public void setAttributes(List<Double> attributes) {
        this.attributes = attributes;
    }

    public String getLabel() {
        return label;
    }

    public void setLabel(String label) {
        this.label = label;
    }
}
package logistic;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Logistic regression classifier for the horse-colic data set, trained with
 * an improved stochastic gradient ascent ("Machine Learning in Action", ch. 5).
 */
public class Logistic {

    /** Directory the data files are resolved against. */
    private static final String DATA_DIR =
            "/home/shenchao/Desktop/MLSourceCode/machinelearninginaction/Ch05/";

    /** Single RNG shared by all passes instead of reseeding one per sample. */
    private final Random random = new Random();

    /**
     * Parses a tab-separated data file into samples: every column but the
     * last is a numeric attribute, the last column is the class label.
     *
     * @param fileName file name, resolved against {@link #DATA_DIR}
     * @return parsed samples, one per input line
     * @throws RuntimeException wrapping the underlying I/O or parse failure
     */
    public List<Horse> initDataSet(String fileName) {
        List<Horse> dataSet = new ArrayList<>();
        // try-with-resources always closes the reader; the original finally
        // block threw NPE when the FileInputStream constructor itself failed.
        try (BufferedReader bufferedReader = new BufferedReader(
                new InputStreamReader(new FileInputStream(DATA_DIR + fileName)))) {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                String[] columns = line.split("\t");
                List<Double> attributes = new ArrayList<>(columns.length - 1);
                for (int i = 0; i < columns.length - 1; i++) {
                    attributes.add(Double.parseDouble(columns[i]));
                }
                Horse horse = new Horse();
                horse.setAttributes(attributes);
                horse.setLabel(columns[columns.length - 1]); // last column = label
                dataSet.add(horse);
            }
            return dataSet;
        } catch (Exception e) {
            // preserve the cause instead of throwing a bare RuntimeException
            throw new RuntimeException("failed to load data set: " + fileName, e);
        }
    }

    /**
     * Improved stochastic gradient ascent: each pass visits every training
     * sample exactly once in random order with a decaying step size.
     *
     * @param trainDataSet training samples (left unmodified)
     * @param numIter      number of passes over the training data
     * @return learned weight vector, one weight per attribute
     */
    public List<Double> stocGradAscent(List<Horse> trainDataSet, int numIter) {
        // Initialise every regression coefficient to 1.0.
        int numAttributes = trainDataSet.get(0).getAttributes().size();
        List<Double> weights = new ArrayList<>(numAttributes);
        for (int i = 0; i < numAttributes; i++) {
            weights.add(1.0);
        }
        for (int i = 0; i < numIter; i++) {
            // Sample without replacement from a per-pass index list. The
            // original removed samples from trainDataSet itself, which emptied
            // the data set after the first pass (so passes 2..numIter did
            // nothing) and clobbered the caller's list.
            List<Integer> remaining = new ArrayList<>(trainDataSet.size());
            for (int k = 0; k < trainDataSet.size(); k++) {
                remaining.add(k);
            }
            for (int j = 0; j < trainDataSet.size(); j++) {
                // Step size decays as training progresses but never reaches 0.
                double alpha = 4.0 / (1.0 + i + j) + 0.01;
                int randIndex = remaining.remove(random.nextInt(remaining.size()));
                Horse sample = trainDataSet.get(randIndex);
                double h = sigmoid(vecMultipVec(sample.getAttributes(), weights));
                double error = Double.parseDouble(sample.getLabel()) - h;
                weights = vecAddVec(weights, alpha, error, sample.getAttributes());
            }
        }
        return weights;
    }

    /**
     * Returns {@code weights + alpha * error * attributes}, element-wise,
     * as a new list; neither argument is modified.
     */
    private List<Double> vecAddVec(List<Double> weights, double alpha,
            double error, List<Double> attributes) {
        List<Double> updated = new ArrayList<>(weights.size());
        for (int i = 0; i < weights.size(); i++) {
            updated.add(weights.get(i) + alpha * error * attributes.get(i));
        }
        return updated;
    }

    /**
     * Computes the inner (dot) product of two equally sized vectors.
     *
     * @param attributes one vector
     * @param weights    the other vector
     * @return the dot product
     */
    private double vecMultipVec(List<Double> attributes, List<Double> weights) {
        double sum = 0.0;
        for (int i = 0; i < attributes.size(); i++) {
            sum += attributes.get(i) * weights.get(i);
        }
        return sum;
    }

    /**
     * Logistic (sigmoid) function: 1 / (1 + e^-x).
     */
    private double sigmoid(double x) {
        return 1.0 / (1 + Math.exp(-x));
    }

    /**
     * Trains on horseColicTraining.txt, classifies horseColicTest.txt and
     * prints and returns the error rate of that single run.
     */
    public double test() {
        List<Horse> trainDataSet = initDataSet("horseColicTraining.txt");
        List<Horse> testDataSet = initDataSet("horseColicTest.txt");
        List<Double> trainWeights = stocGradAscent(trainDataSet, 500);
        int errorCount = 0;
        for (Horse horse : testDataSet) {
            int predicted = (int) classifyVector(horse.getAttributes(), trainWeights);
            int actual = (int) Double.parseDouble(horse.getLabel());
            if (predicted != actual) {
                ++errorCount;
            }
        }
        double errorRate = (double) errorCount / testDataSet.size();
        System.out.println("the error rate of this test is: " + errorRate);
        return errorRate;
    }

    /**
     * Classifies one sample: 1.0 when sigmoid(w·x) exceeds 0.5, else 0.0.
     */
    private double classifyVector(List<Double> attributes,
            List<Double> trainWeights) {
        double prob = sigmoid(vecMultipVec(attributes, trainWeights));
        if (prob > 0.5) {
            return 1.0;
        }
        return 0.0;
    }

    /** Runs 10 independent train/test rounds and prints the mean error rate. */
    public static void main(String[] args) {
        Logistic logistic = new Logistic();
        double sum = 0.0;
        for (int i = 0; i < 10; i++) {
            sum += logistic.test();
        }
        System.out.println("after 10 iterations the average error rate is: " + sum / 10 );
    }
}