Java程序员学算法(6) - 反向传播(Back Propagation)

反向传播算法的最常见的表示图就是如下的神经网络的示意图,这个图或类似的图非常常见,看着也很简单。

上图就是三层感知器(即只含有一个隐藏层的多层感知器)。在网上搜了好久,终于找到了一份比较详细地介绍反向传播算法的文档,大家可以搜索:BP 算法原理和详细推导流程.pdf
这个文档里面详细给出了每一层的计算方式,根据公式就可以写代码了。主要逻辑是:每次取一个样本,做一次前向计算和一次反向更新。当训练次数大于数据量时,会按顺序循环遍历每一个样本,此时训练次数最好是数据量的倍数,这样每个样本被训练的次数相同;否则会随机抽取样本。每次更新后都会用全部训练数据评估一次,并保留偏差最小的那组参数,因此训练次数过大时执行会比较慢。代码如下:

/**
 * Fully connected feed-forward network trained with back propagation, one sample per update.
 *
 * <p>Layer sizes are {@code [inputFeatureCount, hiddenUnits..., outputFeatureCount]}. The
 * parameters connecting layer {@code k} to layer {@code k + 1} live in
 * {@code weightInfoArrays[k]}: a {@code [units(k + 1)][units(k)]} weight matrix and a bias
 * vector of length {@code units(k + 1)}.
 *
 * <p>All data is column-oriented: inputs are {@code data[feature][sample]} and targets/outputs
 * are {@code y[output][sample]}.
 */
public class BackPropagation {

    /** Hidden-layer sizes, ordered from the input side to the output side. */
    private int[] hiddenUnits;
    /** Current parameters; empty until trained or set via {@link #setWeightInfos}. */
    private WeightInfo[] weightInfoArrays;
    /** Sizes of all layers: input, hidden layers, output. */
    private int[] fullUnits;

    private ActivationFunction activationFunction;
    private ApproximationScore approximationScore;

    /**
     * @param hiddenUnits        sizes of the hidden layers
     * @param inputFeatureCount  number of input features
     * @param outputFeatureCount number of outputs
     * @param activationFunction activation applied to every non-input layer
     * @param approximationScore scores predictions against targets during training; the
     *                           lowest score is treated as the best
     */
    public BackPropagation(int[] hiddenUnits, int inputFeatureCount, int outputFeatureCount,
            ActivationFunction activationFunction, ApproximationScore approximationScore){
        this.hiddenUnits = hiddenUnits;
        this.activationFunction = activationFunction;
        this.approximationScore = approximationScore;

        initFullUnits(inputFeatureCount, outputFeatureCount);
    }

    /**
     * Trains the network and keeps the best-scoring weights seen during training.
     *
     * <p>Training starts from the current weights, or from a fresh deterministic initialisation
     * if none are set. Each of the {@code trainCount} steps picks one sample and applies one
     * back-propagation update. The sample is taken by cycling through the data in order when
     * {@code trainCount} exceeds the sample count; otherwise it is drawn uniformly at random.
     * After each update the whole training set is scored with the configured
     * {@link ApproximationScore}. At the end, the weights that produced the lowest score become
     * the model's weights; on ties, the most recent of them is kept.
     *
     * @param y            target values, indexed {@code y[output][sample]}
     * @param learningRate step size used by the weight updates
     * @param trainCount   number of single-sample update steps; must be positive
     * @param data         input features, indexed {@code data[feature][sample]}
     * @throws IllegalArgumentException if {@code trainCount <= 0} or {@code data} has no
     *                                  feature column or no sample
     */
    public void train(double[][] y, double learningRate, int trainCount, double[]... data){
        if (trainCount <= 0) {
            throw new IllegalArgumentException("trainCount must be positive: " + trainCount);
        }
        if (data == null || data.length == 0 || data[0].length == 0) {
            throw new IllegalArgumentException("data must contain at least one feature column and one sample");
        }
        // Training runs without logging; only predict() asks for it.
        final boolean showLog = false;

        if (ArrayUtils.isEmpty(weightInfoArrays)) {
            weightInfoArrays = createWeightInfoArray();
        }
        // Updated in place by every back-propagation step.
        WeightInfo[] workingWeights = weightInfoArrays;

        int dataLength = data[0].length;
        int bestScore = Integer.MAX_VALUE;
        WeightInfo[] bestWeights = null;
        for (int i = 0; i < trainCount; i++){
            int sampleIdx = trainCount > dataLength
                    ? i % dataLength
                    : (int) (Math.random() * dataLength);

            double[] inputDataArray = column(data, sampleIdx);
            double[] outputDataArray = column(y, sampleIdx);

            calculatWeightInfoArray(learningRate, inputDataArray, outputDataArray, workingWeights, showLog);

            // Snapshot so later in-place updates cannot alter a saved candidate.
            WeightInfo[] snapshot = cloneWeightInfoArray(workingWeights);
            double[][] predictData = calculateResult(snapshot, showLog, data);
            int score = approximationScore.score(y, predictData);

            // "<=" keeps the most recent snapshot among equally scored ones.
            if (score <= bestScore) {
                bestScore = score;
                bestWeights = snapshot;
            }
        }

        System.out.println("min: " + bestScore);
        weightInfoArrays = cloneWeightInfoArray(bestWeights);
    }

    /**
     * Runs the network on every sample with the current weights.
     *
     * @param allXData input features, indexed {@code allXData[feature][sample]}
     * @return network outputs, indexed {@code [output][sample]}
     */
    public double[][] predict(double[]... allXData){
        return calculateResult(weightInfoArrays, true, allXData);
    }

    /** Returns the live parameter array (not a copy). */
    public WeightInfo[] getWeightInfos() {
        return weightInfoArrays;
    }

    /**
     * Replaces the parameters. The given array is stored directly, so a later {@link #train}
     * call updates it in place.
     */
    public void setWeightInfos(WeightInfo[] weightInfos) {
        weightInfoArrays = weightInfos;
    }

    /** Runs a forward pass per sample and collects the output-layer activations column by column. */
    private double[][] calculateResult(WeightInfo[] newWeightInfoArray, boolean showLog, double[]... allXData){
        int dataCount = allXData[0].length;
        int outputCount = fullUnits[fullUnits.length - 1];
        double[][] ret = new double[outputCount][dataCount];
        for (int i = 0; i < dataCount; i++){
            LayerInfo[] layers = forward(column(allXData, i), newWeightInfoArray, showLog);
            double[] output = layers[layers.length - 1].getActivationData();
            for (int j = 0; j < outputCount; j++) {
                ret[j][i] = output[j];
            }
        }
        return ret;
    }

    /**
     * Performs one back-propagation step for a single sample, updating
     * {@code newWeightInfoArray} in place.
     *
     * <p>All error terms (deltas) are computed first, from the same weights that produced the
     * forward pass. Only then are the weights updated. Updating a layer before the delta of
     * the layer below it is derived would make that delta use modified weights, which is no
     * longer the gradient.
     */
    private void calculatWeightInfoArray(double learningRate, double[] xdata, double[] ydata,
                                         WeightInfo[] newWeightInfoArray, boolean showLog){
        LayerInfo[] layerArray = forward(xdata, newWeightInfoArray, showLog);
        int lastLayer = layerArray.length - 1;

        // deltas[k] is the error term of layer k (k >= 1); layer 0 is the input and has none.
        double[][] deltas = new double[layerArray.length][];
        deltas[lastLayer] = getLayerOutputDeltaArray(layerArray[lastLayer], ydata);
        for (int k = lastLayer - 1; k > 0; k--){
            deltas[k] = getFixedLayerHiddenDeltaArray(newWeightInfoArray,
                                                      deltas[k + 1],
                                                      layerArray[k].getOriginalData(),
                                                      k);
        }

        // weightInfo[k - 1] connects layer k - 1 to layer k.
        for (int k = lastLayer; k > 0; k--){
            backwardLayerWeightInfoArray(learningRate,
                                         newWeightInfoArray,
                                         deltas[k],
                                         layerArray[k - 1].getActivationData(),
                                         k - 1);
        }
    }

    /**
     * Computes the pre-activation and activation values of every layer for one sample.
     * Layer 0 stores the raw input as both its activation and pre-activation data.
     * {@code showLog} is accepted for API symmetry but is not used here.
     */
    private LayerInfo[] forward(double[] xdata, WeightInfo[] newWeightInfoArray, boolean showLog){
        int layerCount = fullUnits.length;

        LayerInfo[] layerArray = new LayerInfo[layerCount];
        layerArray[0] = new LayerInfo(xdata, xdata);

        for (int i = 1; i < layerCount; i++){
            double[][] layerWeightArray = newWeightInfoArray[i - 1].getWeight();
            double[] layerBiasArray = newWeightInfoArray[i - 1].getBias();

            double[] previousLayerDataArray = layerArray[i - 1].getActivationData();

            double[] originalData = calculateLayerOrignalData(previousLayerDataArray, layerWeightArray, layerBiasArray);
            double[] activationData = calculateLayerActivationData(originalData);
            layerArray[i] = new LayerInfo(activationData, originalData);
        }

        return layerArray;
    }

    /**
     * Computes the error term of hidden layer {@code layerIndex}.
     *
     * <p>The formula is {@code delta[i] = f'(z[i]) * sum_j W[j][i] * nextDelta[j]}, where
     * {@code W = newWeightInfoArray[layerIndex]} connects this layer to the next one.
     *
     * @param newWeightInfoArray      network parameters
     * @param previousLayerDeltaArray error term of the next layer (closer to the output)
     * @param layerDataArray          pre-activation values of this layer
     * @param layerIndex              index of this layer
     * @return error term of this layer
     */
    private double[] getFixedLayerHiddenDeltaArray(WeightInfo[] newWeightInfoArray,
                                                   double[] previousLayerDeltaArray,
                                                   double[] layerDataArray,
                                                   int layerIndex){
        double[] ret = new double[layerDataArray.length];
        double[][] nextWeights = newWeightInfoArray[layerIndex].getWeight();

        for (int i = 0; i < ret.length; i++){
            double sum = 0.0d;
            for (int j = 0; j < previousLayerDeltaArray.length; j++){
                sum += nextWeights[j][i] * previousLayerDeltaArray[j];
            }
            ret[i] = activationFunction.derivative(layerDataArray[i]) * sum;
        }
        return ret;
    }

    /**
     * Applies a gradient-descent step to {@code newWeightInfoArray[layerIndex]} in place.
     *
     * <p>Note: each weight step is scaled by {@code learningRate / (rows * cols)}, and each bias
     * step by {@code learningRate / rows}. The effective learning rate therefore shrinks as the
     * layer gets wider.
     *
     * @param layerDeltaArray        error term of the layer these weights feed into
     * @param previousLayerDataArray activations of the layer these weights read from
     */
    private void backwardLayerWeightInfoArray(double learningRate,
                                              WeightInfo[] newWeightInfoArray,
                                              double[] layerDeltaArray,
                                              double[] previousLayerDataArray,
                                              int layerIndex){

        int weightDeltaLength = layerDeltaArray.length * previousLayerDataArray.length;
        int biasDeltaLength = layerDeltaArray.length;

        WeightInfo weightInfo = newWeightInfoArray[layerIndex];
        double[][] weights = weightInfo.getWeight();
        for (int i = 0; i < weights.length; i++){
            for (int j = 0; j < weights[i].length; j++){
                double offset = (learningRate / weightDeltaLength) * (layerDeltaArray[i] * previousLayerDataArray[j]);
                weights[i][j] -= offset;
            }
        }

        double[] bias = weightInfo.getBias();
        for (int i = 0; i < bias.length; i++){
            bias[i] -= (learningRate / biasDeltaLength) * layerDeltaArray[i];
        }
    }

    /**
     * Computes the output-layer error term {@code f'(z[i]) * (a[i] - y[i])}. This is the
     * gradient of the squared error {@code 0.5 * (a - y)^2} with respect to the
     * pre-activation value.
     *
     * @param layerInfo           output layer of the forward pass
     * @param realOutputDataArray target values for this sample
     */
    private double[] getLayerOutputDeltaArray(LayerInfo layerInfo, double[] realOutputDataArray){
        double[] activation = layerInfo.getActivationData();
        double[] original = layerInfo.getOriginalData();
        double[] ret = new double[activation.length];
        for (int i = 0; i < activation.length; i++){
            double error = activation[i] - realOutputDataArray[i];
            ret[i] = activationFunction.derivative(original[i]) * error;
        }
        return ret;
    }

    /** Computes {@code w * inputData + b}, the pre-activation values of a layer. */
    private double[] calculateLayerOrignalData(double[] inputData, double[][] w, double[] b){
        double[] ret = MatrixHelper.matrixMultiply(w, inputData);
        for (int i = 0; i < ret.length; i++){
            ret[i] = ret[i] + b[i];
        }
        return ret;
    }

    /** Applies the activation function element-wise. */
    private double[] calculateLayerActivationData(double[] originalData){
        double[] ret = new double[originalData.length];
        for (int i = 0; i < ret.length; i++){
            ret[i] = activationFunction.active(originalData[i]);
        }
        return ret;
    }

    /** Builds {@code fullUnits = [featureCount, hiddenUnits..., outputCount]}. */
    private void initFullUnits(int featureCount, int outputCount){
        fullUnits = new int[hiddenUnits.length + 2];
        fullUnits[0] = featureCount;
        fullUnits[fullUnits.length - 1] = outputCount;
        for (int i = 1; i < fullUnits.length - 1; i++){
            fullUnits[i] = hiddenUnits[i - 1];
        }
    }

    /**
     * Creates deterministically initialised parameters for every layer connection. Weights and
     * biases are filled with an increasing sequence in steps of 0.005, starting from 0.
     */
    public WeightInfo[] createWeightInfoArray(){
        WeightInfo[] wia = new WeightInfo[fullUnits.length - 1];
        for (int i = 0; i < wia.length; i++){
            double[][] weightArray = new double[fullUnits[i + 1]][fullUnits[i]];
            assembleWArray(weightArray);
            double[] biasArray = new double[fullUnits[i + 1]];
            assembleBArray(biasArray);
            wia[i] = new WeightInfo(weightArray, biasArray);
        }

        return wia;
    }

    /** Fills the matrix in row-major order with {@code 0, 0.005, 0.010, ...}. */
    private void assembleWArray(double[][] warray){
        double start = 0.0;
        int idx = 0;
        for (int i = 0; i < warray.length; i++){
            for (int j = 0; j < warray[0].length; j++){
                warray[i][j] = start + (idx * 0.005);
                idx++;
            }
        }
    }

    /** Fills the vector with {@code 0, 0.005, 0.010, ...}. */
    private void assembleBArray(double[] barray){
        double start = 0.0;
        for (int i = 0; i < barray.length; i++){
            barray[i] = start + 0.005 * i;
        }
    }

    /** Returns an element-wise copy made via {@code WeightInfo.copy()}. */
    private WeightInfo[] cloneWeightInfoArray(WeightInfo[] wia){
        WeightInfo[] ret = new WeightInfo[wia.length];
        for (int i = 0; i < ret.length; i++){
            ret[i] = wia[i].copy();
        }

        return ret;
    }

    /** Extracts sample {@code col} from a column-oriented matrix {@code matrix[row][col]}. */
    private static double[] column(double[][] matrix, int col){
        double[] ret = new double[matrix.length];
        for (int row = 0; row < matrix.length; row++){
            ret[row] = matrix[row][col];
        }
        return ret;
    }

    /**
     * Returns the smallest element of {@code xa}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code xa} is empty
     */
    public static double getMinValueFrom1D(double[] xa){
        double ret = xa[0];
        for (double d : xa){
            if (d < ret){
                ret = d;
            }
        }
        return ret;
    }
}

代码完了之后,就要测试了。目前这版方法非常适合做分类,因此在网上找到了测试数据(可搜索 iris.data,共 150 条,3 个类别各 50 条)。测试时每 5 条数据留出 1 条作为测试集(共 30 条),其余 120 条用于训练。测试代码如下:

/**
 * Trains a network with one hidden layer of 2 units on the Iris data set. It then prints the
 * misclassification count on the training set and on the held-out test set.
 *
 * <p>Every 5th non-blank line of {@code iris.data} is held out for testing, and the rest is
 * used for training. Each line is {@code f1,f2,f3,f4,class}. The class is one-hot encoded over
 * {@code Iris-setosa, Iris-versicolor, Iris-virginica}.
 */
@Test
public void testIrisData() throws Exception {
    File irisFile = new File("../iris.data");
    List<String> fileLineList = FileUtils.readLines(irisFile, StandardCharsets.UTF_8);

    // Size the arrays from the real number of samples. Blank lines (e.g. a trailing newline)
    // then neither shift the train/test split nor leave zero-filled phantom samples behind.
    int sampleCount = 0;
    for (String line : fileLineList) {
        if (!line.trim().isEmpty()) {
            sampleCount++;
        }
    }
    int testSize = sampleCount / 5;
    int trainSize = sampleCount - testSize;

    // Column-oriented inputs: f1..f4 are the four features of the training samples,
    // t1..t4 the same features of the held-out test samples.
    double[] f1 = new double[trainSize];
    double[] t1 = new double[testSize];

    double[] f2 = new double[trainSize];
    double[] t2 = new double[testSize];

    double[] f3 = new double[trainSize];
    double[] t3 = new double[testSize];

    double[] f4 = new double[trainSize];
    double[] t4 = new double[testSize];

    // One-hot targets: each sample (4 inputs) maps to 3 outputs, one per Iris class.
    double[][] y = new double[3][trainSize];
    double[][] t = new double[3][testSize];
    int fidx = 0;
    int tidx = 0;
    int sampleNo = 0;
    for (String line : fileLineList) {
        if (line.trim().isEmpty()) {
            continue;
        }
        sampleNo++;
        String[] values = StringUtils.splitPreserveAllTokens(line, ',');
        String cat = values[4].trim();

        if (sampleNo % 5 == 0) {
            t1[tidx] = NumberUtils.toDouble(values[0]);
            t2[tidx] = NumberUtils.toDouble(values[1]);
            t3[tidx] = NumberUtils.toDouble(values[2]);
            t4[tidx] = NumberUtils.toDouble(values[3]);
            oneHotIris(cat, t, tidx);
            tidx++;
        } else {
            f1[fidx] = NumberUtils.toDouble(values[0]);
            f2[fidx] = NumberUtils.toDouble(values[1]);
            f3[fidx] = NumberUtils.toDouble(values[2]);
            f4[fidx] = NumberUtils.toDouble(values[3]);
            oneHotIris(cat, y, fidx);
            fidx++;
        }
    }
    System.out.println("datalength: " + f1.length);

    BackPropagation bp = new BackPropagation(new int[]{2}, 4, 3, new Sigmoid(), new ArgmaxEqualScore());

    bp.train(y, 0.1, 8001, f1, f2, f3, f4);

    double[][] predict = bp.predict(f1, f2, f3, f4);

    outputResult(predict, y, "bp train");

    predict = bp.predict(t1, t2, t3, t4);

    outputResult(predict, t, "bp test");
}

/**
 * Writes the one-hot encoding of an Iris class label into column {@code idx} of
 * {@code target} ({@code target[class][sample]}).
 *
 * @throws IllegalArgumentException for an unknown label, rather than silently leaving an
 *                                  all-zero target
 */
private static void oneHotIris(String cat, double[][] target, int idx) {
    int classIdx;
    if ("Iris-setosa".equals(cat)) {
        classIdx = 0;
    } else if ("Iris-versicolor".equals(cat)) {
        classIdx = 1;
    } else if ("Iris-virginica".equals(cat)) {
        classIdx = 2;
    } else {
        throw new IllegalArgumentException("Unknown Iris class: " + cat);
    }
    for (int k = 0; k < target.length; k++) {
        target[k][idx] = (k == classIdx) ? 1 : 0;
    }
}
/**
 * Prints the expected and predicted class of every sample, then the number of mismatches.
 *
 * <p>Both matrices are column-oriented ({@code m[output][sample]}). A sample's class is the
 * index of its largest output value.
 *
 * @param predict network outputs
 * @param y       expected one-hot targets
 * @param flag    label prefixed to every printed line
 * @throws IllegalArgumentException if the matrices describe different numbers of samples
 */
private void outputResult(double[][] predict, double[][] y, String flag) {
    int sampleCount = y[0].length;
    if (predict[0].length != sampleCount) {
        throw new IllegalArgumentException("predict has " + predict[0].length
                + " samples but y has " + sampleCount);
    }

    StringBuilder expectedLine = new StringBuilder(flag + " y argmax: ");
    StringBuilder predictedLine = new StringBuilder(flag + " p argmax: ");
    int mismatches = 0;
    for (int i = 0; i < sampleCount; i++) {
        int expected = argmax1D(columnOf(y, i));
        // Read all rows of predict itself (not y.length rows) so its own width is respected.
        int predicted = argmax1D(columnOf(predict, i));
        expectedLine.append(expected).append(' ');
        predictedLine.append(predicted).append(' ');
        if (expected != predicted) {
            mismatches++;
        }
    }

    System.out.println(expectedLine);
    System.out.println(predictedLine);
    System.out.println(flag + " ret : " + mismatches);
}

/** Extracts column {@code col} of {@code matrix} ({@code matrix[row][col]}). */
private static double[] columnOf(double[][] matrix, int col) {
    double[] ret = new double[matrix.length];
    for (int row = 0; row < matrix.length; row++) {
        ret[row] = matrix[row][col];
    }
    return ret;
}
/**
 * Returns the index of the largest element of a one-dimensional array. On ties the first
 * such index wins.
 *
 * @param a values to search
 * @return index of the maximum element
 * @throws ArrayIndexOutOfBoundsException if {@code a} is empty
 */
public static int argmax1D(double[] a) {
    double maxValue = a[0]; // fails fast on an empty array
    int maxIndex = 0;
    for (int k = 1; k < a.length; k++) {
        if (a[k] > maxValue) {
            maxValue = a[k];
            maxIndex = k;
        }
    }
    return maxIndex;
}

运行此测试,结果如下:

bp train y argmax: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
bp train p argmax: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
bp train ret : 2 

用训练后的参数对训练数据进行分类,120 个样本中错了 2 个。

bp test y argmax: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2
bp test p argmax: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2
bp test ret : 0

代入测试数据,30个全对

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值