基于Java和Flink实现逻辑回归,以及“音吧”APP的性别预测

Java实现逻辑回归


基本矩阵类:Matrix

package flinkjava.LR;

import java.util.ArrayList;
/**
 * Feature matrix stored as nested lists of strings.
 * The outer list holds the rows; each inner list holds the column values of one row.
 */
public class Matrix {
    /** Row list; public so callers may also read and modify it directly. */
    public ArrayList<ArrayList<String>> data;

    /** Creates a matrix that has no rows yet. */
    public Matrix() {
        this.data = new ArrayList<>();
    }

    /**
     * Returns the row list itself (not a copy).
     *
     * @return the live row list
     */
    public ArrayList<ArrayList<String>> getData() {
        return this.data;
    }

    /**
     * Replaces the row list; the argument is stored by reference, not copied.
     *
     * @param data new row list
     */
    public void setData(ArrayList<ArrayList<String>> data) {
        this.data = data;
    }
}

数据集类:包含基本数据Matrix和标签值

package flinkjava.LR;

import java.util.ArrayList;
/**
 * A {@link Matrix} of feature rows together with a list of labels,
 * intended to hold one label per row, in row order.
 */
public class CreateDataSet extends Matrix{
    /** Label list; public so callers may also access it directly. */
    public ArrayList<String> labels;

    /** Creates an empty data set: no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<>();
    }

    /**
     * Returns the label list itself (not a copy).
     *
     * @return the live label list
     */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /**
     * Replaces the label list; the argument is stored by reference, not copied.
     *
     * @param labels new label list
     */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}

LR计算模型

package flinkjava.LR;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

/**
 * Plain-Java logistic regression.
 *
 * Provides three gradient-ascent trainers, a classifier, a tab-separated file reader and two
 * small demos. Rows are lists of numeric strings. The classifier returns "1"/"0" and the demo
 * compares that result to the labels as strings, so test labels should be written as "0"/"1".
 */
public class LR {
    /**
     * Runs the {@link #colicTest()} demo.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        colicTest();
    }

    /**
     * Demo: trains {@link #gradAscent1} for 150 passes on "testSet.txt" and prints every
     * learned weight, followed by an empty line.
     */
    public static void LogisticTest() {
        CreateDataSet dataSet = readFile("testSet.txt");
        ArrayList<Double> weights = gradAscent1(dataSet, dataSet.labels, 150);
        for (double weight : weights) {
            System.out.println(weight);
        }
        System.out.println();
    }

    /**
     * Classifies one sample with a trained weight vector.
     *
     * @param inX     feature values of the sample, laid out like the training rows
     *                (including the leading bias column that {@link #readFile} adds)
     * @param weights trained weights; must have at least {@code inX.size()} entries
     * @return "1" if sigmoid(inX . weights) > 0.5, otherwise "0"
     */
    public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
        double z = 0.0;
        for (int i = 0; i < inX.size(); i++) {
            z = z + Double.parseDouble(inX.get(i)) * weights.get(i);
        }
        return sigmoidOf(z) > 0.5 ? "1" : "0";
    }

    /**
     * Demo: trains {@link #gradAscent1} for 500 passes on "testTraining.txt" and classifies
     * every row of "Test.txt".
     *
     * Prints "predicted,actual" for each row, then the error rate. The error rate is NaN when
     * the test file has no rows.
     */
    public static void colicTest() {
        CreateDataSet trainingSet = readFile("testTraining.txt");
        CreateDataSet testSet = readFile("Test.txt");
        ArrayList<Double> weights = gradAscent1(trainingSet, trainingSet.labels, 500);

        int errorCount = 0;
        for (int i = 0; i < testSet.data.size(); i++) {
            // Classify once and reuse the result for both the error count and the printout.
            String predicted = classifyVector(testSet.data.get(i), weights);
            String actual = testSet.labels.get(i);
            if (!predicted.equals(actual)) {
                errorCount++;
            }
            System.out.println(predicted + "," + actual);
        }
        System.out.println(1.0 * errorCount / testSet.data.size());
    }

    /**
     * Applies the logistic function 1 / (1 + e^-x) element-wise.
     *
     * @param inX input values
     * @return a new list holding the sigmoid of each input, in the same order
     */
    public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
        ArrayList<Double> result = new ArrayList<Double>(inX.size());
        for (int i = 0; i < inX.size(); i++) {
            result.add(sigmoidOf(inX.get(i)));
        }
        return result;
    }

    /**
     * Stochastic gradient ascent with a decaying step size.
     *
     * Each pass visits every row exactly once, in random order (sampling without
     * replacement). The step size is alpha = 4 / (1 + i + j) + 0.0001, where j is the pass
     * number and i is the position inside the pass. It shrinks over time but never reaches
     * zero.
     *
     * @param dataSet     training rows; every row needs at least as many columns as the first
     * @param classLabels numeric labels, one per row, in row order
     * @param numberIter  number of passes over the data
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        int m = x.length;
        int n = x[0].length;
        double[] weights = ones(n);

        // Row numbers not yet visited in the current pass.
        ArrayList<Integer> dataIndex = new ArrayList<Integer>(m);
        for (int j = 0; j < numberIter; j++) {
            dataIndex.clear();
            for (int p = 0; p < m; p++) {
                dataIndex.add(p);
            }
            for (int i = 0; i < m; i++) {
                double alpha = 4 / (1.0 + i + j) + 0.0001;
                // remove(int) returns the stored row number. Using the random position itself
                // as the row would favour low rows and skip high rows late in every pass.
                int row = dataIndex.remove((int) (Math.random() * dataIndex.size()));
                double error = y[row] - sigmoidOf(dot(x[row], weights));
                for (int p = 0; p < n; p++) {
                    weights[p] = weights[p] + alpha * x[row][p] * error;
                }
            }
        }
        return toList(weights);
    }

    /**
     * Single pass of stochastic gradient ascent with a fixed step size (0.01).
     *
     * Rows are visited in file order.
     *
     * @param dataSet     training rows
     * @param classLabels numeric labels, one per row, in row order
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        double[] weights = ones(x[0].length);
        double alpha = 0.01;
        for (int i = 0; i < x.length; i++) {
            double error = y[i] - sigmoidOf(dot(x[i], weights));
            for (int p = 0; p < weights.length; p++) {
                weights[p] = weights[p] + alpha * x[i][p] * error;
            }
        }
        return toList(weights);
    }

    /**
     * Batch gradient ascent: 500 cycles with step size 0.001.
     *
     * Each cycle computes the error for every row using the current weights, then updates
     * all weights at once.
     *
     * @param dataSet     training rows
     * @param classLabels numeric labels, one per row, in row order
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        int m = x.length;
        int n = x[0].length;
        double[] weights = ones(n);
        double[] error = new double[m];
        double alpha = 0.001;
        int maxCycles = 500;
        for (int cycle = 0; cycle < maxCycles; cycle++) {
            // All errors of one cycle use the weights from the end of the previous cycle.
            for (int q = 0; q < m; q++) {
                error[q] = y[q] - sigmoidOf(dot(x[q], weights));
            }
            for (int p = 0; p < n; p++) {
                double step = 0.0;
                for (int q = 0; q < m; q++) {
                    step = step + alpha * x[q][p] * error[q];
                }
                weights[p] = weights[p] + step;
            }
        }
        return toList(weights);
    }

    public LR() {
        super();
    }

    /**
     * Reads a tab-separated data file.
     *
     * In every non-blank line, all columns except the last are features and the last column
     * is the label. A constant "1" (bias / intercept column) is prepended to each feature row.
     * On an I/O error the stack trace is printed and the rows read so far are returned.
     *
     * @param fileName path of the file to read
     * @return the parsed rows and labels
     */
    public static CreateDataSet readFile(String fileName) {
        CreateDataSet dataSet = new CreateDataSet();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Skip blank lines (e.g. a trailing newline); they would produce an empty label.
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split("\t");
                ArrayList<String> row = new ArrayList<String>(fields.length);
                row.add("1");
                for (int i = 0; i < fields.length - 1; i++) {
                    row.add(fields[i]);
                }
                dataSet.data.add(row);
                dataSet.labels.add(fields[fields.length - 1]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return dataSet;
    }

    /** Scalar logistic function 1 / (1 + e^-z). */
    private static double sigmoidOf(double z) {
        return 1.0 / (1 + Math.exp(-z));
    }

    /** Parses every row into doubles, reading as many columns per row as the first row has. */
    private static double[][] toDoubleMatrix(Matrix dataSet) {
        if (dataSet.data.isEmpty()) {
            throw new IllegalArgumentException("dataSet must contain at least one row");
        }
        int n = dataSet.data.get(0).size();
        double[][] x = new double[dataSet.data.size()][n];
        for (int r = 0; r < x.length; r++) {
            ArrayList<String> row = dataSet.data.get(r);
            for (int c = 0; c < n; c++) {
                x[r][c] = Double.parseDouble(row.get(c));
            }
        }
        return x;
    }

    /** Parses the first {@code count} labels as doubles. */
    private static double[] parseLabels(ArrayList<String> labels, int count) {
        double[] y = new double[count];
        for (int r = 0; r < count; r++) {
            y[r] = Double.parseDouble(labels.get(r));
        }
        return y;
    }

    /** Dot product of one row with the weights, accumulated left to right. */
    private static double dot(double[] row, double[] weights) {
        double sum = 0.0;
        for (int k = 0; k < weights.length; k++) {
            sum = sum + row[k] * weights[k];
        }
        return sum;
    }

    /** Returns a new array of {@code n} weights, all set to 1.0. */
    private static double[] ones(int n) {
        double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            w[i] = 1.0;
        }
        return w;
    }

    /** Boxes an array of doubles into a new ArrayList in the same order. */
    private static ArrayList<Double> toList(double[] values) {
        ArrayList<Double> list = new ArrayList<Double>(values.length);
        for (double v : values) {
            list.add(v);
        }
        return list;
    }
}

Flink实现逻辑回归


基本数据类

package flinkjava.LR;

import java.util.ArrayList;

/**
 * One sample for the Flink logistic-regression job: its feature values, its label and the
 * key used to distribute samples into groups.
 */
public class LRinfo {
    /** Feature values, kept as strings. */
    private ArrayList<String> data;
    /** Label value, kept as a string. */
    private String label;
    /** Grouping key (FlinkLR groups by the field expression "groupbyfield"). */
    private String groupbyfield;

    /**
     * @return the feature list as stored (not copied); null until set
     */
    public ArrayList<String> getData() {
        return this.data;
    }

    /**
     * @param data feature list, stored by reference
     */
    public void setData(ArrayList<String> data) {
        this.data = data;
    }

    /**
     * @return the label; null until set
     */
    public String getLabel() {
        return this.label;
    }

    /**
     * @param label label value
     */
    public void setLabel(String label) {
        this.label = label;
    }

    /**
     * @return the grouping key; null until set
     */
    public String getGroupbyfield() {
        return this.groupbyfield;
    }

    /**
     * @param groupbyfield grouping key
     */
    public void setGroupbyfield(String groupbyfield) {
        this.groupbyfield = groupbyfield;
    }
}


数据集类:包含特征矩阵Matrix和标签值(与前文的CreateDataSet相同)

package flinkjava.LR;

import java.util.ArrayList;
/**
 * A {@link Matrix} of feature rows together with a list of labels,
 * intended to hold one label per row, in row order.
 */
public class CreateDataSet extends Matrix{
    /** Label list; public so callers may also access it directly. */
    public ArrayList<String> labels;

    /** Creates an empty data set: no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<>();
    }

    /**
     * Returns the label list itself (not a copy).
     *
     * @return the live label list
     */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /**
     * Replaces the label list; the argument is stored by reference, not copied.
     *
     * @param labels new label list
     */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}

FlinkLR

package flinkjava.LR;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.utils.ParameterTool;

import java.util.*;

/**
 * Flink batch job: trains one logistic-regression model per random group of samples and
 * averages the resulting weight vectors.
 */
public class FlinkLR {
    /**
     * Entry point for the job.
     *
     * Reads the text file given by {@code --input} and turns each line into an LRinfo
     * (LRMap). It then trains one model per group key (LRReduce), collects the per-group
     * weights and prints their element-wise average.
     *
     * @param args command-line arguments; {@code --input <path>} is required
     */
    public static void main(String[] args) {
        final ParameterTool params = ParameterTool.fromArgs(args);

        // set up the execution environment
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // make parameters available in the web interface
        env.getConfig().setGlobalJobParameters(params);

        // get input data (fails fast with a clear message when --input is missing)
        DataSet<String> text = env.readTextFile(params.getRequired("input"));

        DataSet<LRinfo> mapresult = text.map(new LRMap());
        GroupReduceOperator<LRinfo, ArrayList<Double>> reduceresult =
                mapresult.groupBy("groupbyfield").reduceGroup(new LRReduce());
        try {
            // collect() executes the job itself. Do not call env.execute() afterwards: no new
            // data sinks have been defined since this execution, so Flink would reject it.
            List<ArrayList<Double>> groupWeights = reduceresult.collect();
            ArrayList<Double> finalweight = averageWeights(groupWeights);
            System.out.println(finalweight);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Averages weight vectors element-wise.
     *
     * Position i of the result is the sum of element i over all vectors, divided by the
     * number of vectors. Vectors that are too short simply contribute nothing at position i.
     *
     * @param weightLists weight vectors, one per group
     * @return the averaged vector; empty if {@code weightLists} is empty
     */
    private static ArrayList<Double> averageWeights(List<ArrayList<Double>> weightLists) {
        // TreeMap keeps the positions in ascending order.
        Map<Integer, Double> sums = new TreeMap<Integer, Double>();
        for (ArrayList<Double> weights : weightLists) {
            for (int i = 0; i < weights.size(); i++) {
                Double previous = sums.get(i);
                sums.put(i, (previous == null ? 0d : previous) + weights.get(i));
            }
        }
        ArrayList<Double> average = new ArrayList<Double>(sums.size());
        for (double sum : sums.values()) {
            average.add(sum / weightLists.size());
        }
        return average;
    }
}

自定义Map操作

package flinkjava.LR;

import org.apache.commons.lang3.StringUtils;
import org.apache.flink.api.common.functions.MapFunction;

import java.util.ArrayList;
import java.util.Random;

/**
 * Parses one comma-separated text line into an LRinfo.
 *
 * Every column except the last is a feature value and the last column is the label. Each
 * sample also gets a random group key "logic==0" .. "logic==9".
 */
public class LRMap implements MapFunction<String,LRinfo> {

    @Override
    public LRinfo map(String value) throws Exception {
        // NOTE(review): a DataSet map is expected to emit one record per input; returning null
        // for blank lines may fail downstream. Consider a filter or flatMap instead.
        if (StringUtils.isBlank(value)) {
            return null;
        }

        Random random = new Random();
        String[] columns = value.split(",");
        int labelIndex = columns.length - 1;

        ArrayList<String> features = new ArrayList<>();
        int c = 0;
        while (c < labelIndex) {
            features.add(columns[c]);
            c++;
        }

        LRinfo sample = new LRinfo();
        sample.setData(features);
        sample.setLabel(columns[labelIndex]);
        // Random key spreads samples over 10 groups.
        sample.setGroupbyfield("logic==" + random.nextInt(10));
        return sample;
    }
}

自定义ReduceGroup操作

package flinkjava.LR;

import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Iterator;

/**
 * Trains a logistic-regression model on one group of LRinfo samples and emits the learned
 * weights.
 */
public class LRReduce implements GroupReduceFunction<LRinfo, ArrayList<Double>> {
    @Override
    public void reduce(Iterable<LRinfo> values, Collector<ArrayList<Double>> out) throws Exception {
        CreateDataSet groupSet = new CreateDataSet();
        for (LRinfo sample : values) {
            groupSet.getData().add(sample.getData());
            groupSet.getLabels().add(sample.getLabel());
        }
        // 500 passes of stochastic gradient ascent (LR.gradAscent1).
        out.collect(LR.gradAscent1(groupSet, groupSet.labels, 500));
    }
}

基于Flink逻辑回归实现“音吧”APP性别预测


本节对“音吧”APP的用户进行性别预测,因此需要一个表示单条用户数据的基本类SexPreInfo,它包含以下字段:

/**
 * 预测性别的基本元素
 * 用户id
 * 作品总数:worknum
 * 配音频次:workfre
 * 浏览男生配音次数:manDubWorknum
 * 浏览女生配音次数:womanDubWorknum
 * 每天浏览作品频次:workDayfre
 * 填写的性别作为标签:label
 * */

这些数据都来自日志,或由日志数据经过处理后得到。

数据基本类

package com.voicebar.Entity;

import java.util.ArrayList;

/**
 * Feature matrix stored as nested lists of strings.
 * The outer list holds the rows; each inner list holds the column values of one row.
 */
public class Matrix {
    /** Row list; public so callers may also read and modify it directly. */
    public ArrayList<ArrayList<String>> data;

    /** Creates a matrix that has no rows yet. */
    public Matrix() {
        this.data = new ArrayList<>();
    }

    /**
     * Returns the row list itself (not a copy).
     *
     * @return the live row list
     */
    public ArrayList<ArrayList<String>> getData() {
        return this.data;
    }

    /**
     * Replaces the row list; the argument is stored by reference, not copied.
     *
     * @param data new row list
     */
    public void setData(ArrayList<ArrayList<String>> data) {
        this.data = data;
    }
}

含标签的数据类

package com.voicebar.Entity;

import java.util.ArrayList;

/**
 * A {@link Matrix} of feature rows together with a list of labels,
 * intended to hold one label per row, in row order.
 */
public class CreateDataSet extends Matrix{
    /** Label list; public so callers may also access it directly. */
    public ArrayList<String> labels;

    /** Creates an empty data set: no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<>();
    }

    /**
     * Returns the label list itself (not a copy).
     *
     * @return the live label list
     */
    public ArrayList<String> getLabels() {
        return this.labels;
    }

    /**
     * Replaces the label list; the argument is stored by reference, not copied.
     *
     * @param labels new label list
     */
    public void setLabels(ArrayList<String> labels) {
        this.labels = labels;
    }
}

一条数据的类

package com.voicebar.Entity;

/**
 * One user record used for gender prediction.
 *
 * userid: user id
 * worknum: total number of works
 * workfre: dubbing frequency
 * manDubWorknum: number of male-voiced works browsed
 * womanDubWorknum: number of female-voiced works browsed
 * workDayfre: number of works browsed per day
 * label: the gender the user filled in, used as the training label
 * groupfield: grouping key used to split records into training groups
 */
public class SexPreInfo {
    private int userid;
    private int worknum;
    private int workfre;
    private int manDubWorknum;
    private int womanDubWorknum;
    private int workDayfre;
    private int label;
    /**
     * Grouping key. SexPreTask groups records with groupBy("groupfield"), and Flink resolves
     * that expression against a POJO field, so the property has to exist on this class.
     */
    private String groupfield;

    public int getUserid() {
        return userid;
    }

    public void setUserid(int userid) {
        this.userid = userid;
    }

    public int getWorknum() {
        return worknum;
    }

    public void setWorknum(int worknum) {
        this.worknum = worknum;
    }

    public int getWorkfre() {
        return workfre;
    }

    public void setWorkfre(int workfre) {
        this.workfre = workfre;
    }

    public int getManDubWorknum() {
        return manDubWorknum;
    }

    public void setManDubWorknum(int manDubWorknum) {
        this.manDubWorknum = manDubWorknum;
    }

    public int getWomanDubWorknum() {
        return womanDubWorknum;
    }

    public void setWomanDubWorknum(int womanDubWorknum) {
        this.womanDubWorknum = womanDubWorknum;
    }

    public int getWorkDayfre() {
        return workDayfre;
    }

    public void setWorkDayfre(int workDayfre) {
        this.workDayfre = workDayfre;
    }

    public int getLabel() {
        return label;
    }

    public void setLabel(int label) {
        this.label = label;
    }

    /**
     * @return the grouping key; null until set
     */
    public String getGroupfield() {
        return groupfield;
    }

    /**
     * @param groupfield grouping key for distributing records into training groups
     */
    public void setGroupfield(String groupfield) {
        this.groupfield = groupfield;
    }
}

Flink的Task编程

package com.voicebar.task;

import com.voicebar.Entity.SexPreInfo;
import com.voicebar.Map.SexPreMap;
import com.voicebar.Reduce.SexpreReduce;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.GroupReduceOperator;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.utils.ParameterTool;

import java.util.*;

/**
 * Flink batch job for gender prediction on voice-bar app user data.
 *
 * Trains one logistic-regression model per random group of user records and averages the
 * resulting weight vectors.
 */
public class SexPreTask  {
    /**
     * Entry point for the job.
     *
     * Reads user records from {@code --input} and parses them (SexPreMap). It then trains one
     * model per group key (SexpreReduce), collects the per-group weights and prints their
     * element-wise average.
     *
     * @param args command-line arguments; {@code --input <path>} is required
     */
    public static void main(String[] args) {
        final ParameterTool params = ParameterTool.fromArgs(args);
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.getConfig().setGlobalJobParameters(params);

        // Fails fast with a clear message when --input is missing.
        DataSource<String> text = env.readTextFile(params.getRequired("input"));
        MapOperator<String, SexPreInfo> map = text.map(new SexPreMap());
        GroupReduceOperator<SexPreInfo, ArrayList<Double>> reduceresult =
                map.groupBy("groupfield").reduceGroup(new SexpreReduce());

        try {
            // collect() executes the job and returns one trained weight vector per group.
            // Do not call env.execute() afterwards: no new data sinks have been defined since
            // this execution, so Flink would reject it.
            List<ArrayList<Double>> groupWeights = reduceresult.collect();
            // Model averaging: every group's weights contribute equally; no group is selected.
            ArrayList<Double> finalweight = averageWeights(groupWeights);
            System.out.println(finalweight);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Averages weight vectors element-wise.
     *
     * Position i of the result is the sum of element i over all vectors, divided by the
     * number of vectors. Vectors that are too short simply contribute nothing at position i.
     *
     * @param weightLists weight vectors, one per group
     * @return the averaged vector; empty if {@code weightLists} is empty
     */
    private static ArrayList<Double> averageWeights(List<ArrayList<Double>> weightLists) {
        // TreeMap keeps the positions in ascending order.
        Map<Integer, Double> sums = new TreeMap<Integer, Double>();
        for (ArrayList<Double> weights : weightLists) {
            for (int i = 0; i < weights.size(); i++) {
                Double previous = sums.get(i);
                sums.put(i, (previous == null ? 0d : previous) + weights.get(i));
            }
        }
        ArrayList<Double> average = new ArrayList<Double>(sums.size());
        for (double sum : sums.values()) {
            average.add(sum / weightLists.size());
        }
        return average;
    }
}

自定义Map操作

package com.voicebar.Map;

import com.voicebar.Entity.SexPreInfo;
import org.apache.flink.api.common.functions.MapFunction;

import java.util.Random;
/**
 * Parses one CSV line into a SexPreInfo.
 *
 * Expected column order (all integers; surrounding whitespace is ignored, extra columns are
 * ignored):
 * userid,worknum,workfre,manDubWorknum,womanDubWorknum,workDayfre,label
 */
public class SexPreMap implements MapFunction<String, SexPreInfo> {
    /** Number of leading columns this mapper reads. */
    private static final int FIELD_COUNT = 7;

    /**
     * @param value one input line
     * @return the parsed record
     * @throws IllegalArgumentException if the line has fewer than 7 comma-separated columns
     * @throws NumberFormatException    if one of the first 7 columns is not an integer
     */
    public SexPreInfo map(String value) throws Exception {
        String[] temps = value.split(",");
        if (temps.length < FIELD_COUNT) {
            throw new IllegalArgumentException("Expected " + FIELD_COUNT
                    + " comma-separated columns but found " + temps.length + " in line: " + value);
        }

        SexPreInfo sexPreInfo = new SexPreInfo();
        sexPreInfo.setUserid(parseColumn(temps, 0));
        sexPreInfo.setWorknum(parseColumn(temps, 1));
        sexPreInfo.setWorkfre(parseColumn(temps, 2));
        sexPreInfo.setManDubWorknum(parseColumn(temps, 3));
        sexPreInfo.setWomanDubWorknum(parseColumn(temps, 4));
        sexPreInfo.setWorkDayfre(parseColumn(temps, 5));
        sexPreInfo.setLabel(parseColumn(temps, 6));

        // NOTE(review): SexPreTask groups records by the field "groupfield", but nothing here
        // sets a group key on the record. Once SexPreInfo exposes such a property, set it to
        // "sexpre==" + new Random().nextInt(10) to spread records over 10 groups.
        return sexPreInfo;
    }

    /** Parses column {@code index} as an int, ignoring surrounding whitespace. */
    private static int parseColumn(String[] columns, int index) {
        return Integer.parseInt(columns[index].trim());
    }
}

自定义Reduce操作

package com.voicebar.Reduce;

import com.voicebar.Entity.CreateDataSet;
import com.voicebar.Entity.SexPreInfo;
import com.voicebar.Util.LR;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Iterator;

/**
 *
 * private int userid;
 *     private int worknum;
 *     private int workfre;
 *     private int manDubWorknum;
 *     private int womanDubWorknum;
 *     private int workDayfre;
 *     private int label;
 * */
public class SexpreReduce implements GroupReduceFunction<SexPreInfo, ArrayList<Double>> {
    public void reduce(Iterable<SexPreInfo> values, Collector<ArrayList<Double>> out) throws Exception {
        Iterator<SexPreInfo> iterator = values.iterator();
        CreateDataSet trainingSet = new CreateDataSet();
        while(iterator.hasNext()){
            SexPreInfo sexPreInfo = iterator.next();
            int userid = sexPreInfo.getUserid();
            int worknum = sexPreInfo.getWorknum();
            int workfre = sexPreInfo.getWorkfre();
            int manDubWorknum = sexPreInfo.getManDubWorknum();
            int womanDubWorknum = sexPreInfo.getWomanDubWorknum();
            int workDayfre = sexPreInfo.getWorkDayfre();
            int label = sexPreInfo.getLabel();

            ArrayList<String> as = new ArrayList<String>();
            as.add(worknum+"");
            as.add(workfre+"");
            as.add(manDubWorknum+"");
            as.add(womanDubWorknum+"");
            as.add(workDayfre+"");
            as.add(label+"");

            trainingSet.getData().add(as);
            trainingSet.getLabels().add(label+"");
        }
        /**
         * 将构造好的数据放到已经写好的flink逻辑回归实现里面去计算
         * */
        ArrayList<Double> weights = new ArrayList<Double>();
        weights = LR.gradAscent1(trainingSet,trainingSet.getLabels(),500);
        out.collect(weights);
    }
}

LR计算模型

package com.voicebar.Util;

import com.voicebar.Entity.CreateDataSet;
import com.voicebar.Entity.Matrix;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

/**
 * Plain-Java logistic regression.
 *
 * Provides three gradient-ascent trainers, a classifier, a tab-separated file reader and two
 * small demos. Rows are lists of numeric strings. The classifier returns "1"/"0" and the demo
 * compares that result to the labels as strings, so test labels should be written as "0"/"1".
 */
public class LR {
    /**
     * Runs the {@link #colicTest()} demo.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        colicTest();
    }

    /**
     * Demo: trains {@link #gradAscent1} for 150 passes on "testSet.txt" and prints every
     * learned weight, followed by an empty line.
     */
    public static void LogisticTest() {
        CreateDataSet dataSet = readFile("testSet.txt");
        ArrayList<Double> weights = gradAscent1(dataSet, dataSet.labels, 150);
        for (double weight : weights) {
            System.out.println(weight);
        }
        System.out.println();
    }

    /**
     * Classifies one sample with a trained weight vector.
     *
     * @param inX     feature values of the sample, laid out like the training rows
     *                (including the leading bias column that {@link #readFile} adds)
     * @param weights trained weights; must have at least {@code inX.size()} entries
     * @return "1" if sigmoid(inX . weights) > 0.5, otherwise "0"
     */
    public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
        double z = 0.0;
        for (int i = 0; i < inX.size(); i++) {
            z = z + Double.parseDouble(inX.get(i)) * weights.get(i);
        }
        return sigmoidOf(z) > 0.5 ? "1" : "0";
    }

    /**
     * Demo: trains {@link #gradAscent1} for 500 passes on "testTraining.txt" and classifies
     * every row of "Test.txt".
     *
     * Prints "predicted,actual" for each row, then the error rate. The error rate is NaN when
     * the test file has no rows.
     */
    public static void colicTest() {
        CreateDataSet trainingSet = readFile("testTraining.txt");
        CreateDataSet testSet = readFile("Test.txt");
        ArrayList<Double> weights = gradAscent1(trainingSet, trainingSet.labels, 500);

        int errorCount = 0;
        for (int i = 0; i < testSet.data.size(); i++) {
            // Classify once and reuse the result for both the error count and the printout.
            String predicted = classifyVector(testSet.data.get(i), weights);
            String actual = testSet.labels.get(i);
            if (!predicted.equals(actual)) {
                errorCount++;
            }
            System.out.println(predicted + "," + actual);
        }
        System.out.println(1.0 * errorCount / testSet.data.size());
    }

    /**
     * Applies the logistic function 1 / (1 + e^-x) element-wise.
     *
     * @param inX input values
     * @return a new list holding the sigmoid of each input, in the same order
     */
    public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
        ArrayList<Double> result = new ArrayList<Double>(inX.size());
        for (int i = 0; i < inX.size(); i++) {
            result.add(sigmoidOf(inX.get(i)));
        }
        return result;
    }

    /**
     * Stochastic gradient ascent with a decaying step size.
     *
     * Each pass visits every row exactly once, in random order (sampling without
     * replacement). The step size is alpha = 4 / (1 + i + j) + 0.0001, where j is the pass
     * number and i is the position inside the pass. It shrinks over time but never reaches
     * zero.
     *
     * @param dataSet     training rows; every row needs at least as many columns as the first
     * @param classLabels numeric labels, one per row, in row order
     * @param numberIter  number of passes over the data
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        int m = x.length;
        int n = x[0].length;
        double[] weights = ones(n);

        // Row numbers not yet visited in the current pass.
        ArrayList<Integer> dataIndex = new ArrayList<Integer>(m);
        for (int j = 0; j < numberIter; j++) {
            dataIndex.clear();
            for (int p = 0; p < m; p++) {
                dataIndex.add(p);
            }
            for (int i = 0; i < m; i++) {
                double alpha = 4 / (1.0 + i + j) + 0.0001;
                // remove(int) returns the stored row number. Using the random position itself
                // as the row would favour low rows and skip high rows late in every pass.
                int row = dataIndex.remove((int) (Math.random() * dataIndex.size()));
                double error = y[row] - sigmoidOf(dot(x[row], weights));
                for (int p = 0; p < n; p++) {
                    weights[p] = weights[p] + alpha * x[row][p] * error;
                }
            }
        }
        return toList(weights);
    }

    /**
     * Single pass of stochastic gradient ascent with a fixed step size (0.01).
     *
     * Rows are visited in file order.
     *
     * @param dataSet     training rows
     * @param classLabels numeric labels, one per row, in row order
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        double[] weights = ones(x[0].length);
        double alpha = 0.01;
        for (int i = 0; i < x.length; i++) {
            double error = y[i] - sigmoidOf(dot(x[i], weights));
            for (int p = 0; p < weights.length; p++) {
                weights[p] = weights[p] + alpha * x[i][p] * error;
            }
        }
        return toList(weights);
    }

    /**
     * Batch gradient ascent: 500 cycles with step size 0.001.
     *
     * Each cycle computes the error for every row using the current weights, then updates
     * all weights at once.
     *
     * @param dataSet     training rows
     * @param classLabels numeric labels, one per row, in row order
     * @return the learned weights, one per column of the first row (initialised to 1.0)
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {
        double[][] x = toDoubleMatrix(dataSet);
        double[] y = parseLabels(classLabels, x.length);
        int m = x.length;
        int n = x[0].length;
        double[] weights = ones(n);
        double[] error = new double[m];
        double alpha = 0.001;
        int maxCycles = 500;
        for (int cycle = 0; cycle < maxCycles; cycle++) {
            // All errors of one cycle use the weights from the end of the previous cycle.
            for (int q = 0; q < m; q++) {
                error[q] = y[q] - sigmoidOf(dot(x[q], weights));
            }
            for (int p = 0; p < n; p++) {
                double step = 0.0;
                for (int q = 0; q < m; q++) {
                    step = step + alpha * x[q][p] * error[q];
                }
                weights[p] = weights[p] + step;
            }
        }
        return toList(weights);
    }

    public LR() {
        super();
    }

    /**
     * Reads a tab-separated data file.
     *
     * In every non-blank line, all columns except the last are features and the last column
     * is the label. A constant "1" (bias / intercept column) is prepended to each feature row.
     * On an I/O error the stack trace is printed and the rows read so far are returned.
     *
     * @param fileName path of the file to read
     * @return the parsed rows and labels
     */
    public static CreateDataSet readFile(String fileName) {
        CreateDataSet dataSet = new CreateDataSet();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Skip blank lines (e.g. a trailing newline); they would produce an empty label.
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split("\t");
                ArrayList<String> row = new ArrayList<String>(fields.length);
                row.add("1");
                for (int i = 0; i < fields.length - 1; i++) {
                    row.add(fields[i]);
                }
                dataSet.data.add(row);
                dataSet.labels.add(fields[fields.length - 1]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return dataSet;
    }

    /** Scalar logistic function 1 / (1 + e^-z). */
    private static double sigmoidOf(double z) {
        return 1.0 / (1 + Math.exp(-z));
    }

    /** Parses every row into doubles, reading as many columns per row as the first row has. */
    private static double[][] toDoubleMatrix(Matrix dataSet) {
        if (dataSet.data.isEmpty()) {
            throw new IllegalArgumentException("dataSet must contain at least one row");
        }
        int n = dataSet.data.get(0).size();
        double[][] x = new double[dataSet.data.size()][n];
        for (int r = 0; r < x.length; r++) {
            ArrayList<String> row = dataSet.data.get(r);
            for (int c = 0; c < n; c++) {
                x[r][c] = Double.parseDouble(row.get(c));
            }
        }
        return x;
    }

    /** Parses the first {@code count} labels as doubles. */
    private static double[] parseLabels(ArrayList<String> labels, int count) {
        double[] y = new double[count];
        for (int r = 0; r < count; r++) {
            y[r] = Double.parseDouble(labels.get(r));
        }
        return y;
    }

    /** Dot product of one row with the weights, accumulated left to right. */
    private static double dot(double[] row, double[] weights) {
        double sum = 0.0;
        for (int k = 0; k < weights.length; k++) {
            sum = sum + row[k] * weights[k];
        }
        return sum;
    }

    /** Returns a new array of {@code n} weights, all set to 1.0. */
    private static double[] ones(int n) {
        double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            w[i] = 1.0;
        }
        return w;
    }

    /** Boxes an array of doubles into a new ArrayList in the same order. */
    private static ArrayList<Double> toList(double[] values) {
        ArrayList<Double> list = new ArrayList<Double>(values.length);
        for (double v : values) {
            list.add(v);
        }
        return list;
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小满锅lock

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值