Flink实现逻辑回归并进行性别预测
Java实现逻辑回归
基本矩阵类:Matrix
package flinkjava.LR;
import java.util.ArrayList;
/**
* 保存特征信息
* 主要保存特征矩阵
* */
/**
 * Feature matrix: a row-major grid of feature values stored as strings.
 */
public class Matrix {
    // Outer list = rows (samples), inner list = columns (features) of one row.
    public ArrayList<ArrayList<String>> data;

    /** Creates an empty matrix with no rows. */
    public Matrix() {
        this.data = new ArrayList<ArrayList<String>>();
    }

    /** @return the backing row list (not a copy) */
    public ArrayList<ArrayList<String>> getData() {
        return this.data;
    }

    /** @param data replaces the backing row list */
    public void setData(ArrayList<ArrayList<String>> data) {
        this.data = data;
    }
}
数据集类:包含基本数据Matrix和标签值
package flinkjava.LR;
import java.util.ArrayList;
/***
* 主要保存特征信息以及标签值
* labels:主要保存标签值
* */
/**
 * A labelled data set: the inherited feature matrix plus one label per row.
 */
public class CreateDataSet extends Matrix {
    // labels.get(i) is the class label of feature row data.get(i).
    public ArrayList<String> labels;

    /** Creates an empty data set with no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<String>();
    }

    /** @return the backing label list (not a copy) */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /** @param labels replaces the backing label list */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}
LR计算模型
package flinkjava.LR;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class LR {
    /**
     * Entry point: runs {@link #colicTest()} as a smoke test of the
     * train/predict pipeline.
     */
    public static void main(String[] args) {
        colicTest();
    }

    /**
     * Trains on "testSet.txt" for 150 epochs of stochastic gradient ascent
     * and prints the first three learned weights.
     */
    public static void LogisticTest() {
        CreateDataSet dataSet = readFile("testSet.txt");
        ArrayList<Double> weights = gradAscent1(dataSet, dataSet.labels, 150);
        for (int i = 0; i < 3; i++) {
            System.out.println(weights.get(i));
        }
        System.out.println();
    }

    /**
     * Classifies a single sample with the learned weights.
     *
     * @param inX     feature values of one sample, as numeric strings
     * @param weights weight vector of the same dimension as inX
     * @return "1" if sigmoid(inX . weights) > 0.5, otherwise "0"
     */
    public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
        // Dot product of the sample with the weight vector.
        double dot = 0.0;
        for (int i = 0; i < inX.size(); i++) {
            dot += Double.parseDouble(inX.get(i)) * weights.get(i);
        }
        // Decision threshold at probability 0.5.
        return (1.0 / (1 + Math.exp(-dot)) > 0.5) ? "1" : "0";
    }

    /**
     * Trains on "testTraining.txt", evaluates on "Test.txt", prints every
     * "prediction,truth" pair and finally the error rate.
     * Expected file format per line: tab-separated features, label last.
     */
    public static void colicTest() {
        CreateDataSet trainingSet = readFile("testTraining.txt");
        CreateDataSet testSet = readFile("Test.txt");
        // 500 epochs of stochastic gradient ascent.
        ArrayList<Double> weights = gradAscent1(trainingSet, trainingSet.labels, 500);
        int errorCount = 0;
        for (int i = 0; i < testSet.data.size(); i++) {
            String predicted = classifyVector(testSet.data.get(i), weights);
            if (!predicted.equals(testSet.labels.get(i))) {
                errorCount++;
            }
            System.out.println(predicted + "," + testSet.labels.get(i));
        }
        System.out.println(1.0 * errorCount / testSet.data.size());
    }

    /**
     * Element-wise logistic function 1 / (1 + e^-x).
     *
     * @param inX input values
     * @return a new list holding the sigmoid of each input value
     */
    public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
        ArrayList<Double> out = new ArrayList<Double>();
        for (int i = 0; i < inX.size(); i++) {
            out.add(1.0 / (1 + Math.exp(-inX.get(i))));
        }
        return out;
    }

    /**
     * Stochastic gradient ascent with a decaying step size; every epoch
     * visits each training row exactly once, in random order.
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @param numberIter  number of epochs over the data
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
        int m = dataSet.data.size();        // number of rows (samples)
        int n = dataSet.data.get(0).size(); // number of columns (features)
        // All weights start at 1.0.
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        // One-element buffer reused to call the list-based sigmoid().
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        dataMatrixMulweights.add(0.0);
        // Pool of not-yet-visited row ids for the current epoch.
        ArrayList<Integer> dataIndex = new ArrayList<Integer>();
        for (int j = 0; j < numberIter; j++) {
            dataIndex.clear();
            for (int p = 0; p < m; p++) {
                dataIndex.add(p);
            }
            for (int i = 0; i < m; i++) {
                // Step size decays with both the epoch and sample counters,
                // but the +0.0001 floor keeps it from reaching zero.
                double alpha = 4 / (1.0 + i + j) + 0.0001;
                // Draw a random row without replacement.
                // BUG FIX: the original used the random position in dataIndex
                // directly as the row id, which biases sampling toward
                // low-numbered rows and can visit the same row several times
                // per epoch; the stored id must be used instead.
                int pick = (int) (Math.random() * dataIndex.size());
                int row = dataIndex.remove(pick);
                // dot = x . w for the chosen row.
                double dot = 0.0;
                for (int k = 0; k < n; k++) {
                    dot += Double.parseDouble(dataSet.data.get(row).get(k)) * weights.get(k);
                }
                dataMatrixMulweights.set(0, dot);
                double h = sigmoid(dataMatrixMulweights).get(0);
                // Gradient ascent on the log-likelihood: w += alpha * x * (y - h).
                double error = Double.parseDouble(classLabels.get(row)) - h;
                for (int p = 0; p < n; p++) {
                    weights.set(p, weights.get(p)
                            + alpha * Double.parseDouble(dataSet.data.get(row).get(p)) * error);
                }
            }
        }
        return weights;
    }

    /**
     * One sequential pass of stochastic gradient ascent with a fixed step
     * size (alpha = 0.01), visiting the rows in order.
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {
        int m = dataSet.data.size();
        int n = dataSet.data.get(0).size();
        double alpha = 0.01;
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        // One-element buffer reused to call the list-based sigmoid().
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        dataMatrixMulweights.add(0.0);
        for (int i = 0; i < m; i++) {
            double dot = 0.0;
            for (int k = 0; k < n; k++) {
                dot += Double.parseDouble(dataSet.data.get(i).get(k)) * weights.get(k);
            }
            dataMatrixMulweights.set(0, dot);
            double h = sigmoid(dataMatrixMulweights).get(0);
            // w += alpha * x * (y - h)
            double error = Double.parseDouble(classLabels.get(i)) - h;
            for (int p = 0; p < n; p++) {
                weights.set(p, weights.get(p)
                        + alpha * Double.parseDouble(dataSet.data.get(i).get(p)) * error);
            }
        }
        return weights;
    }

    /**
     * Batch gradient ascent: 500 fixed iterations, each using all m rows,
     * with a fixed step size (alpha = 0.001).
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {
        int m = dataSet.data.size();
        int n = dataSet.data.get(0).size();
        double alpha = 0.001;
        int maxCycles = 500;
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        ArrayList<Double> h;
        ArrayList<Double> error = new ArrayList<Double>();
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        for (int i = 0; i < m; i++) {
            error.add(0.0);
            dataMatrixMulweights.add(0.0);
        }
        for (int i = 0; i < maxCycles; i++) {
            // h = sigmoid(X . w) for every row.
            for (int j = 0; j < m; j++) {
                double dot = 0.0;
                for (int k = 0; k < n; k++) {
                    dot += Double.parseDouble(dataSet.data.get(j).get(k)) * weights.get(k);
                }
                dataMatrixMulweights.set(j, dot);
            }
            h = sigmoid(dataMatrixMulweights);
            // error = y - h, per row.
            for (int q = 0; q < m; q++) {
                error.set(q, Double.parseDouble(classLabels.get(q)) - h.get(q));
            }
            // w += alpha * X^T . error
            for (int p = 0; p < n; p++) {
                double delta = 0.0;
                for (int q = 0; q < m; q++) {
                    delta += alpha * Double.parseDouble(dataSet.data.get(q).get(p)) * error.get(q);
                }
                weights.set(p, weights.get(p) + delta);
            }
        }
        return weights;
    }

    public LR() {
        super();
    }

    /**
     * Reads a tab-separated data file into a CreateDataSet. Each line holds
     * the feature values followed by the label in the last column; a
     * constant "1" bias/intercept term is prepended to every feature row.
     * I/O errors are printed and an empty (or partial) data set is returned.
     *
     * @param fileName path of the file to read
     * @return the parsed data set
     */
    public static CreateDataSet readFile(String fileName) {
        CreateDataSet dataSet = new CreateDataSet();
        // NOTE(review): FileReader uses the platform default charset — fine
        // for ASCII numeric data, but specify a charset if that ever changes.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String tempString;
            while ((tempString = reader.readLine()) != null) {
                String[] strArr = tempString.split("\t");
                ArrayList<String> as = new ArrayList<String>();
                as.add("1"); // bias/intercept term x0 = 1
                for (int i = 0; i < strArr.length - 1; i++) {
                    as.add(strArr[i]);
                }
                dataSet.data.add(as);
                dataSet.labels.add(strArr[strArr.length - 1]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return dataSet;
    }
}
Flink实现逻辑回归
基本数据类
package flinkjava.LR;
import java.util.ArrayList;
/**
 * One training sample for the Flink logistic-regression job: the feature
 * values, the class label, and the random grouping key used by groupBy.
 */
public class LRinfo {
    private ArrayList<String> data; // feature values as numeric strings
    private String label;           // class label of this sample
    private String groupbyfield;    // random partition key, e.g. "logic==3"

    public ArrayList<String> getData() {
        return this.data;
    }

    public void setData(ArrayList<String> data) {
        this.data = data;
    }

    public String getLabel() {
        return this.label;
    }

    public void setLabel(String label) {
        this.label = label;
    }

    public String getGroupbyfield() {
        return this.groupbyfield;
    }

    public void setGroupbyfield(String groupbyfield) {
        this.groupbyfield = groupbyfield;
    }
}
数据集类:继承自Matrix,包含特征矩阵和标签值(训练时由一组LRinfo组装而成)
package flinkjava.LR;
import java.util.ArrayList;
/***
* 主要保存特征信息以及标签值
* labels:主要保存标签值
* */
/**
 * A labelled data set: the inherited feature matrix plus one label per row.
 */
public class CreateDataSet extends Matrix {
    // labels.get(i) is the class label of feature row data.get(i).
    public ArrayList<String> labels;

    /** Creates an empty data set with no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<String>();
    }

    /** @return the backing label list (not a copy) */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /** @param labels replaces the backing label list */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}
FlinkLR
package flinkjava.LR;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.utils.ParameterTool;
import java.util.*;
public class FlinkLR {
    /**
     * Driver for the distributed logistic-regression job: reads input lines,
     * trains one weight vector per random group in parallel (see LRMap /
     * LRReduce), then averages the per-group vectors into the final model.
     */
    public static void main(String[] args) {
        final ParameterTool params = ParameterTool.fromArgs(args);
        // set up the execution environment
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        // make parameters available in the web interface
        env.getConfig().setGlobalJobParameters(params);
        // get input data
        DataSet<String> text = env.readTextFile(params.get("input"));
        DataSet<LRinfo> mapresult = text.map(new LRMap());
        // One trained weight vector per random group key.
        GroupReduceOperator<LRinfo, ArrayList<Double>> reduceresult =
                mapresult.groupBy("groupbyfield").reduceGroup(new LRReduce());
        try {
            // collect() triggers job execution and returns the per-group weights.
            List<ArrayList<Double>> resultlist = reduceresult.collect();
            int groupsize = resultlist.size();
            // Sum the weights position by position, keyed by weight index
            // (TreeMap keeps the indices in ascending order).
            Map<Integer, Double> summap = new TreeMap<Integer, Double>();
            for (ArrayList<Double> array : resultlist) {
                for (int i = 0; i < array.size(); i++) {
                    double pre = summap.get(i) == null ? 0d : summap.get(i);
                    summap.put(i, pre + array.get(i));
                }
            }
            // Average: divide each summed weight by the number of groups.
            ArrayList<Double> finalweight = new ArrayList<Double>();
            for (Map.Entry<Integer, Double> mapentry : summap.entrySet()) {
                finalweight.add(mapentry.getValue() / groupsize);
            }
            // BUG FIX: the original discarded finalweight and then called
            // env.execute(); collect() has already run the job, and a second
            // execute() with no newly defined sinks throws an exception in
            // Flink's batch API.
            System.out.println(finalweight);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
自定义Map操作
package flinkjava.LR;
import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.functions.MapFunction;
import java.util.ArrayList;
import java.util.Random;
/**
 * Parses one CSV line into an LRinfo: every field but the last becomes a
 * feature value, the last field is the label, and a random key in
 * "logic==0".."logic==9" spreads the samples over ten training groups.
 */
public class LRMap implements MapFunction<String, LRinfo> {
    @Override
    public LRinfo map(String value) throws Exception {
        // NOTE(review): Flink's MapFunction contract does not permit null
        // results; blank lines should instead be dropped with a filter()
        // before this map, or the operator rewritten as a FlatMapFunction.
        if (StringUtils.isBlank(value)) {
            return null;
        }
        String[] fields = value.split(",");
        ArrayList<String> features = new ArrayList<>();
        for (int i = 0; i < fields.length - 1; i++) {
            features.add(fields[i]);
        }
        LRinfo info = new LRinfo();
        info.setData(features);
        info.setLabel(fields[fields.length - 1]);
        info.setGroupbyfield("logic==" + new Random().nextInt(10));
        return info;
    }
}
自定义ReduceGroup操作
package flinkjava.LR;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Iterator;
/**
* 输入进来的是一连串的LRinfo
* */
/**
 * Receives every LRinfo of one group, assembles them into a CreateDataSet,
 * and emits the weight vector trained by LR.gradAscent1 over 500 epochs.
 */
public class LRReduce implements GroupReduceFunction<LRinfo, ArrayList<Double>> {
    @Override
    public void reduce(Iterable<LRinfo> values, Collector<ArrayList<Double>> out) throws Exception {
        CreateDataSet trainingSet = new CreateDataSet();
        // Copy each sample's features and label into the training set.
        for (LRinfo info : values) {
            trainingSet.getData().add(info.getData());
            trainingSet.getLabels().add(info.getLabel());
        }
        // Train on this group's data and emit the resulting weights.
        out.collect(LR.gradAscent1(trainingSet, trainingSet.labels, 500));
    }
}
基于Flink逻辑回归实现“音吧”APP性别预测
这个性别预测是基于“音吧”APP实现的,因此需要一个基本的数据单元SexPreInfo,它包含以下字段:
/**
* 预测性别的基本元素
* 用户id
* 作品总数:worknum
* 配音频次:workfre
* 浏览男生配音次数:manDubWorknum
* 浏览女生配音次数:womanDubWorknum
* 每天浏览作品频次:workDayfre
* 填写的性别作为标签:label
* */
这些数据都是直接从日志中获取的,或由日志经过预处理后得到的。
数据基本类
package com.voicebar.Entity;
import java.util.ArrayList;
/**
* 保存特征信息
* 主要保存特征矩阵
* */
/**
 * Feature matrix: a row-major grid of feature values stored as strings.
 */
public class Matrix {
    // Outer list = rows (samples), inner list = columns (features) of one row.
    public ArrayList<ArrayList<String>> data;

    /** Creates an empty matrix with no rows. */
    public Matrix() {
        this.data = new ArrayList<ArrayList<String>>();
    }

    /** @return the backing row list (not a copy) */
    public ArrayList<ArrayList<String>> getData() {
        return this.data;
    }

    /** @param data replaces the backing row list */
    public void setData(ArrayList<ArrayList<String>> data) {
        this.data = data;
    }
}
含标签的数据类
package com.voicebar.Entity;
import java.util.ArrayList;
/***
* 主要保存特征信息以及标签值
* labels:主要保存标签值
* */
/**
 * A labelled data set: the inherited feature matrix plus one label per row.
 */
public class CreateDataSet extends Matrix {
    // labels.get(i) is the class label of feature row data.get(i).
    public ArrayList<String> labels;

    /** Creates an empty data set with no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<String>();
    }

    /** @return the backing label list (not a copy) */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /** @param labels replaces the backing label list */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}
一条数据的类
package com.voicebar.Entity;
/**
* 预测性别的基本元素
* 用户id
* 作品总数:worknum
* 配音频次:workfre
* 浏览男生配音次数:manDubWorknum
* 浏览女生配音次数:womanDubWorknum
* 每天浏览作品频次:workDayfre
* 填写的性别作为标签:label
* */
/**
 * One user's gender-prediction sample, built from app logs.
 * Fields: user id, total works, dubbing frequency, views of male dubbing
 * works, views of female dubbing works, works browsed per day, and the
 * self-reported gender used as the training label.
 */
public class SexPreInfo {
    private int userid;          // user id
    private int worknum;         // total number of works
    private int workfre;         // dubbing frequency
    private int manDubWorknum;   // views of male dubbing works
    private int womanDubWorknum; // views of female dubbing works
    private int workDayfre;      // works browsed per day
    private int label;           // self-reported gender, used as the label
    // CONSISTENCY FIX: SexPreTask partitions with groupBy("groupfield"),
    // which requires a POJO field of exactly this name; SexPreMap computes
    // the key but previously had nowhere to store it. Added as a
    // backward-compatible property.
    private String groupfield;

    public int getUserid() {
        return userid;
    }

    public void setUserid(int userid) {
        this.userid = userid;
    }

    public int getWorknum() {
        return worknum;
    }

    public void setWorknum(int worknum) {
        this.worknum = worknum;
    }

    public int getWorkfre() {
        return workfre;
    }

    public void setWorkfre(int workfre) {
        this.workfre = workfre;
    }

    public int getManDubWorknum() {
        return manDubWorknum;
    }

    public void setManDubWorknum(int manDubWorknum) {
        this.manDubWorknum = manDubWorknum;
    }

    public int getWomanDubWorknum() {
        return womanDubWorknum;
    }

    public void setWomanDubWorknum(int womanDubWorknum) {
        this.womanDubWorknum = womanDubWorknum;
    }

    public int getWorkDayfre() {
        return workDayfre;
    }

    public void setWorkDayfre(int workDayfre) {
        this.workDayfre = workDayfre;
    }

    public int getLabel() {
        return label;
    }

    public void setLabel(int label) {
        this.label = label;
    }

    public String getGroupfield() {
        return groupfield;
    }

    public void setGroupfield(String groupfield) {
        this.groupfield = groupfield;
    }
}
Flink的Task编程
package com.voicebar.task;
import com.voicebar.Entity.SexPreInfo;
import com.voicebar.Map.SexPreMap;
import com.voicebar.Reduce.SexpreReduce;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.utils.ParameterTool;
import java.util.*;
public class SexPreTask {
    /**
     * Gender-prediction training job: reads CSV log lines, trains one
     * logistic-regression weight vector per random group (see SexPreMap /
     * SexpreReduce), then averages the group vectors into the final model.
     */
    public static void main(String[] args) {
        final ParameterTool params = ParameterTool.fromArgs(args);
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.getConfig().setGlobalJobParameters(params);
        DataSource<String> text = env.readTextFile(params.get("input"));
        MapOperator<String, SexPreInfo> map = text.map(new SexPreMap());
        // NOTE(review): groupBy("groupfield") requires SexPreInfo to expose a
        // field named exactly "groupfield", and SexPreMap must populate it;
        // otherwise Flink rejects the key when building the job graph.
        GroupReduceOperator<SexPreInfo, ArrayList<Double>> reduceresult =
                map.groupBy("groupfield").reduceGroup(new SexpreReduce());
        try {
            // collect() triggers execution and returns one weight vector per
            // group; pick the final model by averaging them.
            List<ArrayList<Double>> resultlist = reduceresult.collect();
            int groupsize = resultlist.size();
            // Sum the weights position by position, keyed by weight index
            // (TreeMap keeps the indices in ascending order).
            Map<Integer, Double> summap = new TreeMap<Integer, Double>();
            for (ArrayList<Double> array : resultlist) {
                for (int i = 0; i < array.size(); i++) {
                    double pre = summap.get(i) == null ? 0d : summap.get(i);
                    summap.put(i, pre + array.get(i));
                }
            }
            // Average: divide each summed weight by the number of groups.
            ArrayList<Double> finalweight = new ArrayList<Double>();
            for (Map.Entry<Integer, Double> mapentry : summap.entrySet()) {
                finalweight.add(mapentry.getValue() / groupsize);
            }
            // BUG FIX: the original discarded finalweight and then called
            // env.execute(); collect() has already run the job, and a second
            // execute() with no newly defined sinks throws an exception in
            // Flink's batch API.
            System.out.println(finalweight);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}
自定义Map操作
package com.voicebar.Map;
import com.voicebar.Entity.SexPreInfo;
import org.apache.flink.api.common.functions.MapFunction;
import java.util.Random;
/**
*
* private int userid;
* private int worknum;
* private int workfre;
* private int manDubWorknum;
* private int womanDubWorknum;
* private int workDayfre;
* private int label;
* */
/**
 * Parses one CSV line of the form
 * "userid,worknum,workfre,manDubWorknum,womanDubWorknum,workDayfre,label"
 * into a SexPreInfo.
 */
public class SexPreMap implements MapFunction<String, SexPreInfo> {
    public SexPreInfo map(String value) throws Exception {
        Random random = new Random();
        String[] fields = value.split(",");
        SexPreInfo info = new SexPreInfo();
        info.setUserid(Integer.valueOf(fields[0]));
        info.setWorknum(Integer.valueOf(fields[1]));
        info.setWorkfre(Integer.valueOf(fields[2]));
        info.setManDubWorknum(Integer.valueOf(fields[3]));
        info.setWomanDubWorknum(Integer.valueOf(fields[4]));
        info.setWorkDayfre(Integer.valueOf(fields[5]));
        info.setLabel(Integer.valueOf(fields[6]));
        // Random partition key in "sexpre==0".."sexpre==9".
        // NOTE(review): this key is computed but never stored on the POJO, so
        // the downstream groupBy("groupfield") in SexPreTask has nothing to
        // read — SexPreInfo needs a groupfield property and it must be set here.
        String groupfield = "sexpre==" + random.nextInt(10);
        return info;
    }
}
自定义Reduce操作
package com.voicebar.Reduce;
import com.voicebar.Entity.CreateDataSet;
import com.voicebar.Entity.SexPreInfo;
import com.voicebar.Util.LR;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Iterator;
/**
*
* private int userid;
* private int worknum;
* private int workfre;
* private int manDubWorknum;
* private int womanDubWorknum;
* private int workDayfre;
* private int label;
* */
/**
 * Receives all SexPreInfo of one group, builds a CreateDataSet from the
 * five behavioural features plus the gender label, then emits the weight
 * vector trained by LR.gradAscent1 over 500 epochs.
 */
public class SexpreReduce implements GroupReduceFunction<SexPreInfo, ArrayList<Double>> {
    public void reduce(Iterable<SexPreInfo> values, Collector<ArrayList<Double>> out) throws Exception {
        CreateDataSet trainingSet = new CreateDataSet();
        for (SexPreInfo info : values) {
            // Feature vector: the five behavioural counters as strings.
            ArrayList<String> features = new ArrayList<String>();
            features.add(info.getWorknum() + "");
            features.add(info.getWorkfre() + "");
            features.add(info.getManDubWorknum() + "");
            features.add(info.getWomanDubWorknum() + "");
            features.add(info.getWorkDayfre() + "");
            // BUG FIX: the original also appended the label to the feature
            // vector, leaking the prediction target into the model's inputs;
            // the label belongs only in the labels list.
            trainingSet.getData().add(features);
            trainingSet.getLabels().add(info.getLabel() + "");
        }
        // Train on this group's data and emit the resulting weights.
        ArrayList<Double> weights = LR.gradAscent1(trainingSet, trainingSet.getLabels(), 500);
        out.collect(weights);
    }
}
LR计算模型
package com.voicebar.Util;
import com.voicebar.Entity.CreateDataSet;
import com.voicebar.Entity.Matrix;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class LR {
    /**
     * Entry point: runs {@link #colicTest()} as a smoke test of the
     * train/predict pipeline.
     */
    public static void main(String[] args) {
        colicTest();
    }

    /**
     * Trains on "testSet.txt" for 150 epochs of stochastic gradient ascent
     * and prints the first three learned weights.
     */
    public static void LogisticTest() {
        CreateDataSet dataSet = readFile("testSet.txt");
        ArrayList<Double> weights = gradAscent1(dataSet, dataSet.labels, 150);
        for (int i = 0; i < 3; i++) {
            System.out.println(weights.get(i));
        }
        System.out.println();
    }

    /**
     * Classifies a single sample with the learned weights.
     *
     * @param inX     feature values of one sample, as numeric strings
     * @param weights weight vector of the same dimension as inX
     * @return "1" if sigmoid(inX . weights) > 0.5, otherwise "0"
     */
    public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
        // Dot product of the sample with the weight vector.
        double dot = 0.0;
        for (int i = 0; i < inX.size(); i++) {
            dot += Double.parseDouble(inX.get(i)) * weights.get(i);
        }
        // Decision threshold at probability 0.5.
        return (1.0 / (1 + Math.exp(-dot)) > 0.5) ? "1" : "0";
    }

    /**
     * Trains on "testTraining.txt", evaluates on "Test.txt", prints every
     * "prediction,truth" pair and finally the error rate.
     * Expected file format per line: tab-separated features, label last.
     */
    public static void colicTest() {
        CreateDataSet trainingSet = readFile("testTraining.txt");
        CreateDataSet testSet = readFile("Test.txt");
        // 500 epochs of stochastic gradient ascent.
        ArrayList<Double> weights = gradAscent1(trainingSet, trainingSet.labels, 500);
        int errorCount = 0;
        for (int i = 0; i < testSet.data.size(); i++) {
            String predicted = classifyVector(testSet.data.get(i), weights);
            if (!predicted.equals(testSet.labels.get(i))) {
                errorCount++;
            }
            System.out.println(predicted + "," + testSet.labels.get(i));
        }
        System.out.println(1.0 * errorCount / testSet.data.size());
    }

    /**
     * Element-wise logistic function 1 / (1 + e^-x).
     *
     * @param inX input values
     * @return a new list holding the sigmoid of each input value
     */
    public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
        ArrayList<Double> out = new ArrayList<Double>();
        for (int i = 0; i < inX.size(); i++) {
            out.add(1.0 / (1 + Math.exp(-inX.get(i))));
        }
        return out;
    }

    /**
     * Stochastic gradient ascent with a decaying step size; every epoch
     * visits each training row exactly once, in random order.
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @param numberIter  number of epochs over the data
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
        int m = dataSet.data.size();        // number of rows (samples)
        int n = dataSet.data.get(0).size(); // number of columns (features)
        // All weights start at 1.0.
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        // One-element buffer reused to call the list-based sigmoid().
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        dataMatrixMulweights.add(0.0);
        // Pool of not-yet-visited row ids for the current epoch.
        ArrayList<Integer> dataIndex = new ArrayList<Integer>();
        for (int j = 0; j < numberIter; j++) {
            dataIndex.clear();
            for (int p = 0; p < m; p++) {
                dataIndex.add(p);
            }
            for (int i = 0; i < m; i++) {
                // Step size decays with both the epoch and sample counters,
                // but the +0.0001 floor keeps it from reaching zero.
                double alpha = 4 / (1.0 + i + j) + 0.0001;
                // Draw a random row without replacement.
                // BUG FIX: the original used the random position in dataIndex
                // directly as the row id, which biases sampling toward
                // low-numbered rows and can visit the same row several times
                // per epoch; the stored id must be used instead.
                int pick = (int) (Math.random() * dataIndex.size());
                int row = dataIndex.remove(pick);
                // dot = x . w for the chosen row.
                double dot = 0.0;
                for (int k = 0; k < n; k++) {
                    dot += Double.parseDouble(dataSet.data.get(row).get(k)) * weights.get(k);
                }
                dataMatrixMulweights.set(0, dot);
                double h = sigmoid(dataMatrixMulweights).get(0);
                // Gradient ascent on the log-likelihood: w += alpha * x * (y - h).
                double error = Double.parseDouble(classLabels.get(row)) - h;
                for (int p = 0; p < n; p++) {
                    weights.set(p, weights.get(p)
                            + alpha * Double.parseDouble(dataSet.data.get(row).get(p)) * error);
                }
            }
        }
        return weights;
    }

    /**
     * One sequential pass of stochastic gradient ascent with a fixed step
     * size (alpha = 0.01), visiting the rows in order.
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {
        int m = dataSet.data.size();
        int n = dataSet.data.get(0).size();
        double alpha = 0.01;
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        // One-element buffer reused to call the list-based sigmoid().
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        dataMatrixMulweights.add(0.0);
        for (int i = 0; i < m; i++) {
            double dot = 0.0;
            for (int k = 0; k < n; k++) {
                dot += Double.parseDouble(dataSet.data.get(i).get(k)) * weights.get(k);
            }
            dataMatrixMulweights.set(0, dot);
            double h = sigmoid(dataMatrixMulweights).get(0);
            // w += alpha * x * (y - h)
            double error = Double.parseDouble(classLabels.get(i)) - h;
            for (int p = 0; p < n; p++) {
                weights.set(p, weights.get(p)
                        + alpha * Double.parseDouble(dataSet.data.get(i).get(p)) * error);
            }
        }
        return weights;
    }

    /**
     * Batch gradient ascent: 500 fixed iterations, each using all m rows,
     * with a fixed step size (alpha = 0.001).
     *
     * @param dataSet     training samples (rows of numeric strings)
     * @param classLabels label ("0"/"1") for each row of dataSet
     * @return learned weight vector, one weight per feature column
     */
    public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {
        int m = dataSet.data.size();
        int n = dataSet.data.get(0).size();
        double alpha = 0.001;
        int maxCycles = 500;
        ArrayList<Double> weights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);
        }
        ArrayList<Double> h;
        ArrayList<Double> error = new ArrayList<Double>();
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        for (int i = 0; i < m; i++) {
            error.add(0.0);
            dataMatrixMulweights.add(0.0);
        }
        for (int i = 0; i < maxCycles; i++) {
            // h = sigmoid(X . w) for every row.
            for (int j = 0; j < m; j++) {
                double dot = 0.0;
                for (int k = 0; k < n; k++) {
                    dot += Double.parseDouble(dataSet.data.get(j).get(k)) * weights.get(k);
                }
                dataMatrixMulweights.set(j, dot);
            }
            h = sigmoid(dataMatrixMulweights);
            // error = y - h, per row.
            for (int q = 0; q < m; q++) {
                error.set(q, Double.parseDouble(classLabels.get(q)) - h.get(q));
            }
            // w += alpha * X^T . error
            for (int p = 0; p < n; p++) {
                double delta = 0.0;
                for (int q = 0; q < m; q++) {
                    delta += alpha * Double.parseDouble(dataSet.data.get(q).get(p)) * error.get(q);
                }
                weights.set(p, weights.get(p) + delta);
            }
        }
        return weights;
    }

    public LR() {
        super();
    }

    /**
     * Reads a tab-separated data file into a CreateDataSet. Each line holds
     * the feature values followed by the label in the last column; a
     * constant "1" bias/intercept term is prepended to every feature row.
     * I/O errors are printed and an empty (or partial) data set is returned.
     *
     * @param fileName path of the file to read
     * @return the parsed data set
     */
    public static CreateDataSet readFile(String fileName) {
        CreateDataSet dataSet = new CreateDataSet();
        // NOTE(review): FileReader uses the platform default charset — fine
        // for ASCII numeric data, but specify a charset if that ever changes.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String tempString;
            while ((tempString = reader.readLine()) != null) {
                String[] strArr = tempString.split("\t");
                ArrayList<String> as = new ArrayList<String>();
                as.add("1"); // bias/intercept term x0 = 1
                for (int i = 0; i < strArr.length - 1; i++) {
                    as.add(strArr[i]);
                }
                dataSet.data.add(as);
                dataSet.labels.add(strArr[strArr.length - 1]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return dataSet;
    }
}