logistic回归 java_【JAVA实现】用Logistic回归进行分类

package logistic;

import java.util.List;

/**
 * One labelled sample of the data set: a numeric feature vector plus the
 * class label kept as the raw text read from the data file.
 *
 * <p>This is a plain mutable bean; no validation or defensive copying is
 * performed, so the list handed to {@link #setAttributes(List)} is the same
 * instance later returned by {@link #getAttributes()}.
 */
public class Horse {

    /** Feature values of this sample, in column order. */
    private List<Double> attributes;

    /** Class label exactly as it appeared in the input (not parsed). */
    private String label;

    /**
     * @return the feature values of this sample, or {@code null} if none were set
     */
    public List<Double> getAttributes() {
        return attributes;
    }

    /**
     * @param attributes the feature values to store (kept by reference)
     */
    public void setAttributes(List<Double> attributes) {
        this.attributes = attributes;
    }

    /**
     * @return the class label text, or {@code null} if none was set
     */
    public String getLabel() {
        return label;
    }

    /**
     * @param label the class label text to store
     */
    public void setLabel(String label) {
        this.label = label;
    }

}

package logistic;

import java.io.BufferedReader;

import java.io.FileInputStream;

import java.io.IOException;

import java.io.InputStreamReader;

import java.util.ArrayList;

import java.util.List;

import java.util.Random;

/**
 * Binary logistic-regression classifier trained with a stochastic gradient
 * ascent whose step size decays as training proceeds.
 *
 * <p>Samples are read from tab-separated text files in which every column but
 * the last is a numeric feature and the last column is the class label.
 */
public class Logistic {

    /** Directory that {@link #initDataSet(String)} prepends to the file name it is given. */
    private static final String DATA_DIR =
            "/home/shenchao/Desktop/MLSourceCode/machinelearninginaction/Ch05/";

    /**
     * Loads a data set from a tab-separated file located in {@link #DATA_DIR}.
     *
     * <p>Each non-blank line becomes one {@link Horse}: all fields except the
     * last are parsed as doubles and stored as attributes, and the last field
     * is stored verbatim as the label. Blank lines are skipped.
     *
     * @param fileName name of the file, relative to the data directory
     * @return the samples in file order
     * @throws RuntimeException if the file cannot be read or a feature field is
     *         not a valid number; the underlying exception is kept as the cause
     */
    public List<Horse> initDataSet(String fileName) {
        String path = DATA_DIR + fileName;
        List<Horse> dataSet = new ArrayList<>();
        // try-with-resources closes the reader on every path and, unlike a manual
        // finally block, cannot fail with a NullPointerException when the open fails.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // A blank line would otherwise yield a sample with no features
                    // and an empty label, which breaks training later on.
                    continue;
                }
                String[] fields = line.split("\t");
                int labelIndex = fields.length - 1;
                List<Double> attributes = new ArrayList<>(labelIndex);
                for (int i = 0; i < labelIndex; i++) {
                    attributes.add(Double.parseDouble(fields[i]));
                }
                Horse horse = new Horse();
                horse.setAttributes(attributes);
                horse.setLabel(fields[labelIndex]);
                dataSet.add(horse);
            }
            return dataSet;
        } catch (IOException | NumberFormatException e) {
            throw new RuntimeException("Failed to load data set from " + path, e);
        }
    }

    /**
     * Improved stochastic gradient ascent.
     *
     * <p>Performs {@code numIter} passes over the training set. Within a pass
     * every sample is visited exactly once, in random order, and the weights
     * are updated after each sample. The step size is
     * {@code 4 / (1 + pass + step) + 0.01}, so it shrinks over time but never
     * reaches zero. The training list itself is not modified.
     *
     * @param trainDataSet training set; must contain at least one sample, and
     *        every label must be parseable as a double
     * @param numIter number of passes over the training set
     * @return the weight vector, one entry per attribute of the first sample
     * @throws IllegalArgumentException if {@code trainDataSet} is empty
     */
    public List<Double> stocGradAscent(List<Horse> trainDataSet, int numIter) {
        if (trainDataSet.isEmpty()) {
            throw new IllegalArgumentException("trainDataSet must contain at least one sample");
        }

        // Initialise every regression coefficient to 1.0.
        int featureCount = trainDataSet.get(0).getAttributes().size();
        List<Double> weights = new ArrayList<>(featureCount);
        for (int i = 0; i < featureCount; i++) {
            weights.add(1.0);
        }

        int sampleCount = trainDataSet.size();
        int[] pool = new int[sampleCount];
        Random random = new Random();

        for (int i = 0; i < numIter; i++) {
            // Refill the index pool so that each pass sees the whole training set.
            // Removing samples from trainDataSet itself would shrink it for good
            // and leave later passes with nothing to train on.
            for (int k = 0; k < sampleCount; k++) {
                pool[k] = k;
            }
            int remaining = sampleCount;

            for (int j = 0; j < sampleCount; j++) {
                double alpha = 4.0 / (1.0 + i + j) + 0.01;

                // Draw without replacement: take a random slot, then overwrite it
                // with the last live slot so the pool stays contiguous.
                int slot = random.nextInt(remaining);
                Horse sample = trainDataSet.get(pool[slot]);
                pool[slot] = pool[--remaining];

                double h = sigmoid(vecMultipVec(sample.getAttributes(), weights));
                double error = Double.parseDouble(sample.getLabel()) - h;
                weights = vecAddVec(weights, alpha, error, sample.getAttributes());
            }
        }
        return weights;
    }

    /**
     * Returns {@code weights + alpha * error * attributes} as a new list.
     *
     * @param weights current weight vector
     * @param alpha step size
     * @param error label minus predicted probability for the sample
     * @param attributes feature vector of the sample
     * @return the updated weight vector; the inputs are left untouched
     */
    private List<Double> vecAddVec(List<Double> weights, double alpha,
            double error, List<Double> attributes) {
        List<Double> updated = new ArrayList<>(weights.size());
        for (int i = 0; i < weights.size(); i++) {
            updated.add(weights.get(i) + alpha * error * attributes.get(i));
        }
        return updated;
    }

    /**
     * Computes the inner product of two vectors.
     *
     * @param attributes feature vector; its length drives the loop
     * @param weights weight vector, at least as long as {@code attributes}
     * @return the sum of the element-wise products
     */
    private double vecMultipVec(List<Double> attributes, List<Double> weights) {
        double sum = 0.0;
        for (int i = 0; i < attributes.size(); i++) {
            sum += attributes.get(i) * weights.get(i);
        }
        return sum;
    }

    /**
     * Logistic function {@code 1 / (1 + e^-x)}.
     *
     * @param x input value
     * @return a value between 0 and 1
     */
    private double sigmoid(double x) {
        return 1.0 / (1 + Math.exp(-x));
    }

    /**
     * Trains on {@code horseColicTraining.txt} for 500 passes, classifies
     * every sample of {@code horseColicTest.txt}, prints the error rate and
     * returns it.
     *
     * @return the fraction of test samples that were misclassified
     */
    public double test() {
        List<Horse> trainDataSet = initDataSet("horseColicTraining.txt");
        List<Horse> testDataSet = initDataSet("horseColicTest.txt");
        List<Double> trainWeights = stocGradAscent(trainDataSet, 500);

        int errorCount = 0;
        for (Horse horse : testDataSet) {
            int predicted = (int) classifyVector(horse.getAttributes(), trainWeights);
            int actual = (int) Double.parseDouble(horse.getLabel());
            if (predicted != actual) {
                ++errorCount;
            }
        }

        double errorRate = (double) errorCount / testDataSet.size();
        System.out.println("the error rate of this test is: " + errorRate);
        return errorRate;
    }

    /**
     * Classifies one sample with the given weights.
     *
     * @param attributes feature vector of the sample
     * @param trainWeights trained weight vector
     * @return {@code 1.0} if the predicted probability exceeds 0.5, otherwise {@code 0.0}
     */
    private double classifyVector(List<Double> attributes,
            List<Double> trainWeights) {
        double prob = sigmoid(vecMultipVec(attributes, trainWeights));
        return prob > 0.5 ? 1.0 : 0.0;
    }

    /**
     * Runs {@link #test()} ten times and prints the average error rate.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Logistic logistic = new Logistic();
        double sum = 0.0;
        for (int i = 0; i < 10; i++) {
            sum += logistic.test();
        }
        System.out.println("after 10 iterations the average error rate is: " + sum / 10 );
    }

}

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值