JavaDay21

Source: 日撸 Java 三百行 (Days 71-80, BP neural network), Min Fan's blog on CSDN

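BP Neural Network (Base Class)

1. BP neural network

A BP (back-propagation) neural network is a multilayer feedforward neural network. Its defining feature is that signals propagate forward through the layers while errors propagate backward.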


Figure 1. Structure of a multilayer neural network

2. Forward pass

1) The input layer receives the vector $X = (x_1, x_2, x_3, \ldots, x_{N_1})$;

2) Combining X with the weight matrix $w_{ih}$ gives the hidden layer $Y = (y_1, y_2, y_3, \ldots, y_{N_2})$;

3) Combining Y with the weight matrix $w_{hj}$ gives the output layer $Z = (z_1, z_2, z_3, \ldots, z_{N_3})$.
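Each step of the forward pass is a weighted sum per node of the next layer. Below is a minimal Java sketch, assuming a fully connected layer stored as `weights[i][h]` (weight from node i to node h), a logistic sigmoid activation, and no bias term. The text above does not fix these choices, and `ForwardSketch` / `layerForward` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of the forward pass through fully connected layers. */
class ForwardSketch {

    /** Logistic sigmoid, an assumed (but common) activation for BP networks. */
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Computes one layer: out[h] = sigmoid(sum_i in[i] * weights[i][h]).
     * weights has in.length rows and one column per node of the next layer.
     */
    static double[] layerForward(double[] in, double[][] weights) {
        int numOut = weights[0].length;
        double[] out = new double[numOut];
        for (int h = 0; h < numOut; h++) {
            double sum = 0;
            for (int i = 0; i < in.length; i++) {
                sum += in[i] * weights[i][h];
            }
            out[h] = sigmoid(sum);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 0.5};                                 // N1 = 2 inputs
        double[][] wIh = {{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}};     // 2 x 3
        double[][] wHj = {{0.7, 0.8}, {0.9, 1.0}, {1.1, 1.2}};   // 3 x 2
        double[] y = layerForward(x, wIh);                       // hidden layer Y
        double[] z = layerForward(y, wHj);                       // output layer Z
        System.out.println("Y = " + Arrays.toString(y));
        System.out.println("Z = " + Arrays.toString(z));
    }
}
```

Calling `layerForward` twice, first with $w_{ih}$ and then with $w_{hj}$, mirrors steps 2) and 3).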

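3. Backward pass

1) Compute the prediction error from the predicted output of the output layer and the expected output;

2) Propagate this error backward through the network and update the weight matrices $w_{hj}$ and $w_{ih}$ accordingly.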
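A common way to carry out these two steps is the delta rule with a momentum term, which matches the `learningRate` and `mobp` (momentum coefficient) fields declared in the class below. This is a minimal sketch assuming sigmoid units, squared-error loss, one hidden layer and no bias. It is not necessarily how subclasses of GeneralAnn implement backPropagation, and `BackwardSketch`, `layer` and `backward` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch: one BP update for a network with one sigmoid hidden layer. */
class BackwardSketch {

    /** Sigmoid layer: out[k] = 1 / (1 + exp(-sum_i in[i] * w[i][k])), no bias. */
    static double[] layer(double[] in, double[][] w) {
        double[] out = new double[w[0].length];
        for (int k = 0; k < out.length; k++) {
            double sum = 0;
            for (int i = 0; i < in.length; i++) {
                sum += in[i] * w[i][k];
            }
            out[k] = 1.0 / (1.0 + Math.exp(-sum));
        }
        return out;
    }

    /**
     * Updates wHj (hidden -> output) and wIh (input -> hidden) in place, given the
     * activations x, y, z of one forward pass and the one-hot target t.
     * prevIh / prevHj store the previous weight changes (momentum memory).
     */
    static void backward(double[] x, double[] y, double[] z, double[] t,
            double[][] wIh, double[][] wHj, double[][] prevIh, double[][] prevHj,
            double learningRate, double mobp) {
        // Step 1: output error terms for sigmoid units and squared error.
        double[] deltaOut = new double[z.length];
        for (int j = 0; j < z.length; j++) {
            deltaOut[j] = z[j] * (1 - z[j]) * (t[j] - z[j]);
        }
        // Propagate the error back to the hidden layer (uses wHj before the update).
        double[] deltaHid = new double[y.length];
        for (int h = 0; h < y.length; h++) {
            double sum = 0;
            for (int j = 0; j < z.length; j++) {
                sum += deltaOut[j] * wHj[h][j];
            }
            deltaHid[h] = y[h] * (1 - y[h]) * sum;
        }
        // Step 2: change = learningRate * delta * input + mobp * previous change.
        for (int h = 0; h < y.length; h++) {
            for (int j = 0; j < z.length; j++) {
                prevHj[h][j] = learningRate * deltaOut[j] * y[h] + mobp * prevHj[h][j];
                wHj[h][j] += prevHj[h][j];
            }
        }
        for (int i = 0; i < x.length; i++) {
            for (int h = 0; h < y.length; h++) {
                prevIh[i][h] = learningRate * deltaHid[h] * x[i] + mobp * prevIh[i][h];
                wIh[i][h] += prevIh[i][h];
            }
        }
    }

    public static void main(String[] args) {
        double[] x = {1.0, 0.0};
        double[] t = {1.0, 0.0};
        double[][] wIh = {{0.1, -0.2}, {0.3, 0.4}};
        double[][] wHj = {{0.2, 0.1}, {-0.1, 0.3}};
        double[][] prevIh = new double[2][2];
        double[][] prevHj = new double[2][2];
        for (int step = 0; step < 5; step++) {
            double[] y = layer(x, wIh);
            double[] z = layer(y, wHj);
            System.out.println("step " + step + ": Z = " + Arrays.toString(z));
            backward(x, y, z, t, wIh, wHj, prevIh, prevHj, 0.5, 0.3);
        }
    }
}
```

Repeated updates on a single instance should drift Z toward the target t. The momentum term reuses part of the previous change, which tends to smooth the weight updates.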

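4. Example

Suppose the input vector is $X = [0.8, 0.6, 0.1, 0.2, 1]$ and the weight vectors are $w_{ih} = [0.5, 0.1, 0.2, 0.2, 0.3]$ and $w_{hj} = [0.3, 0.2, 0.1, 0.1, 0.4]$. To keep the arithmetic simple, the example uses no activation function and lets the k-th node of the next layer multiply the sum of all nodes in the previous layer by the k-th weight. Since $\sum_i x_i = 2.7$, the hidden layer is $Y = 2.7 \cdot w_{ih} = [1.35, 0.27, 0.54, 0.54, 0.81]$. Since $\sum_h y_h = 3.51$, the output layer is $Z = 3.51 \cdot w_{hj} = [1.053, 0.702, 0.351, 0.351, 1.404] \approx [1.05, 0.7, 0.35, 0.35, 1.4]$. The predicted class is the one with the largest output value, $z_5 = 1.404$ (index 4 with 0-based indexing, which is what argmax in the code below returns).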
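The figures in this example follow if each node k of the next layer multiplies the sum of all nodes in the previous layer by the k-th weight, with no activation function. A minimal check under that reading, using a hypothetical `ExampleCheck` class:

```java
import java.util.Arrays;

/** Hypothetical check of the worked example (no activation function). */
class ExampleCheck {

    /** out[k] = w[k] * (sum of all inputs): every input reaches node k with weight w[k]. */
    static double[] propagate(double[] in, double[] w) {
        double sum = 0;
        for (double v : in) {
            sum += v;
        }
        double[] out = new double[w.length];
        for (int k = 0; k < w.length; k++) {
            out[k] = w[k] * sum;
        }
        return out;
    }

    /** 0-based index of the first largest value. */
    static int argmax(double[] a) {
        int best = 0;
        for (int k = 1; k < a.length; k++) {
            if (a[k] > a[best]) {
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] x = {0.8, 0.6, 0.1, 0.2, 1};
        double[] wIh = {0.5, 0.1, 0.2, 0.2, 0.3};
        double[] wHj = {0.3, 0.2, 0.1, 0.1, 0.4};
        double[] y = propagate(x, wIh);
        double[] z = propagate(y, wHj);
        System.out.println("Y = " + Arrays.toString(y));
        System.out.println("Z = " + Arrays.toString(z));
        System.out.println("predicted index = " + argmax(z));
    }
}
```

Up to floating-point rounding, this gives Y = [1.35, 0.27, 0.54, 0.54, 0.81], Z = [1.053, 0.702, 0.351, 0.351, 1.404] and predicted index 4.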

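The code of the base class is as follows. It uses Weka's Instances class to load an ARFF file, so Weka must be on the classpath: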

package JavaDay21;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

/**
 * General ANN. Two methods are abstract: forward and backPropagation.
 *
 * @author Ke-Xiong Wang
 */
public abstract class GeneralAnn {

    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * Number of layers. It is counted according to nodes instead of edges.
     */
    int numLayers;

    /**
     * The number of nodes for each layer, e.g., [3, 4, 6, 2] means that there
     * are 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6
     * nodes, respectively, and 2 class values (binary classification).
     */
    int[] layerNumNodes;

    /**
     * Momentum coefficient.
     */
    public double mobp;

    /**
     * Learning rate.
     */
    public double learningRate;

    /**
     * For random number generation.
     */
    Random random = new Random();

    /**
     ********************
     * The first constructor.
     *
     * @param paraFilename
     *            The arff filename.
     * @param paraLayerNumNodes
     *            The number of nodes for each layer (may be different).
     * @param paraLearningRate
     *            Learning rate.
     * @param paraMobp
     *            Momentum coefficient.
     ********************
     */
    public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
                      double paraMobp) {
        // Step 1. Read data.
        // try-with-resources closes the reader even if parsing fails.
        try (FileReader tempReader = new FileReader(paraFilename)) {
            dataset = new Instances(tempReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read '" + paraFilename
                    + "' in GeneralAnn constructor.\r\n" + ee);
            // Exit with a non-zero status to signal failure.
            System.exit(1);
        } // Of try

        // Step 2. Accept parameters.
        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        // Adjust if necessary.
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        learningRate = paraLearningRate;
        mobp = paraMobp;
    }//Of the first constructor

    /**
     ********************
     * Forward prediction.
     *
     * @param paraInput
     *            The input data of one instance.
     * @return The data at the output end.
     ********************
     */
    public abstract double[] forward(double[] paraInput);

    /**
     ********************
     * Back propagation.
     *
     * @param paraTarget
     *            The one-hot target vector. For 3-class data, it is [0, 0, 1],
     *            [0, 1, 0] or [1, 0, 0].
     *
     ********************
     */
    public abstract void backPropagation(double[] paraTarget);

    /**
     ********************
     * Train using the dataset.
     ********************
     */
    public void train() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Fill the class label.
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            // Train with this instance.
            forward(tempInput);
            backPropagation(tempTarget);
        } // Of for i
    }// Of train

    /**
     ********************
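     * Get the index corresponding to the max value of the array.
     *
     * @param paraArray
     *            The array to search.
     * @return The index of the first maximum, or -1 if the array is empty.
     ********************
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        double tempMax = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < paraArray.length; i++) {
            // Always accept the first element, so a non-empty array yields a valid
            // index even if all values are very negative (the old -1e10 start failed).
            if (resultIndex == -1 || tempMax < paraArray[i]) {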
                tempMax = paraArray[i];
                resultIndex = i;
            } // Of if
        } // Of for i

        return resultIndex;
    }// Of argmax

    /**
     ********************
     * Test on the same dataset that is used for training.
     *
     * @return The accuracy, i.e., the fraction of correctly classified instances.
     ********************
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];

        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Predict the class of this instance.
            tempPrediction = forward(tempInput);
            //System.out.println("prediction: " + Arrays.toString(tempPrediction));
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            } // Of if
        } // Of for i

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

        return tempNumCorrect / dataset.numInstances();
    }// Of test
}//Of class GeneralAnn
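train() turns each class label into a one-hot target before calling backPropagation, and test() maps a prediction back to a class index with argmax. That encoding round trip can be tried without Weka. The helpers `oneHot` and `argmax` in the hypothetical `LabelEncodingSketch` class below mirror those loop bodies; they are not part of the original class.

```java
import java.util.Arrays;

/** Hypothetical, Weka-free sketch of the label encoding used by train() and test(). */
class LabelEncodingSketch {

    /** One-hot target as built in train(): all zeros except a 1 at the class index. */
    static double[] oneHot(int classIndex, int numClasses) {
        double[] target = new double[numClasses];
        Arrays.fill(target, 0); // redundant for a fresh array, kept to mirror train()
        target[classIndex] = 1;
        return target;
    }

    /** Index of the first maximum, or -1 for an empty array. */
    static int argmax(double[] values) {
        int resultIndex = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (resultIndex == -1 || max < values[i]) {
                max = values[i];
                resultIndex = i;
            }
        }
        return resultIndex;
    }

    public static void main(String[] args) {
        int numClasses = 3;
        for (int c = 0; c < numClasses; c++) {
            double[] target = oneHot(c, numClasses);
            System.out.println(c + " -> " + Arrays.toString(target) + " -> " + argmax(target));
        }
        // A network output is decoded the same way: the largest activation wins.
        System.out.println(argmax(new double[] {0.1, 0.7, 0.2})); // prints 1
    }
}
```

The loop prints lines such as `0 -> [1.0, 0.0, 0.0] -> 0`, showing that argmax inverts the one-hot encoding.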
