bp神经网络 java实现_BP神经网络的Java实现

本文介绍了使用Java实现BP(Back Propagation)神经网络的详细步骤,包括权重初始化、前向传播、误差反向传播和权重调整等核心过程。通过示例展示了如何训练和测试神经网络,用于判断输入数字的正负奇偶性。
摘要由CSDN通过智能技术生成

BP(Back Propagation)网络是1986年由Rumelhart和McClelland为首的科学家小组提出,是一种按误差逆传播算法训练的多层前馈网络,是目前应用最广泛的神经网络模型之一。BP网络能学习和存贮大量的输入-输出模式映射关系,而无需事前揭示描述这种映射关系的数学方程。

import java.util.Random;

/**
 * A three-layer (input, hidden, output) feed-forward neural network trained
 * with error back-propagation (BPNN) and a momentum term.
 *
 * <p>Every layer array and weight matrix is allocated one slot larger than the
 * size requested by the caller. Index 0 of the input and hidden layers is
 * forced to 1.0 before it is used and serves as the threshold unit; caller
 * data lives at indices {@code 1..size}. Index 0 of the output, target and
 * delta arrays is never read or written.
 *
 * <p>Instances keep the state of the most recent call in their fields and use
 * no synchronization.
 *
 * @author RenaQiu
 */
public class BP {

    /** Learning rate used by the three-argument constructor. */
    private static final double DEFAULT_ETA = 0.25;

    /** Momentum factor used by the three-argument constructor. */
    private static final double DEFAULT_MOMENTUM = 0.9;

    /** Fixed seed, so two networks of the same shape start with equal weights. */
    private static final long WEIGHT_SEED = 19881211L;

    /** Message of the exception thrown when a vector has the wrong length. */
    private static final String SIZE_MISMATCH = "Size Do Not Match.";

    /** Input vector; slot 0 is the threshold unit. */
    private final double[] input;

    /** Hidden layer activations; slot 0 is the threshold unit. */
    private final double[] hidden;

    /** Output layer activations; slot 0 is unused. */
    private final double[] output;

    /** Target values of the sample being trained; slot 0 is unused. */
    private final double[] target;

    /** Delta vector of the hidden layer. */
    private final double[] hidDelta;

    /** Delta vector of the output layer. */
    private final double[] optDelta;

    /** Learning rate. */
    private final double eta;

    /** Momentum factor applied to the previous weight update. */
    private final double momentum;

    /** Weight matrix from input layer to hidden layer, indexed [from][to]. */
    private final double[][] iptHidWeights;

    /** Weight matrix from hidden layer to output layer, indexed [from][to]. */
    private final double[][] hidOptWeights;

    /** Previous update applied to {@link #iptHidWeights}. */
    private final double[][] iptHidPrevUptWeights;

    /** Previous update applied to {@link #hidOptWeights}. */
    private final double[][] hidOptPrevUptWeights;

    /** Sum of absolute output-layer deltas of the most recent {@link #train} call. */
    public double optErrSum = 0d;

    /** Sum of absolute hidden-layer deltas of the most recent {@link #train} call. */
    public double hidErrSum = 0d;

    /** Source of the initial weights. */
    private final Random random;

    /**
     * Creates a network with randomly initialised weights.
     *
     * <p>Note: the capacity of each layer is the given size plus 1; the
     * additional unit at index 0 is the threshold unit.
     *
     * @param inputSize number of input values per sample
     * @param hiddenSize number of hidden units
     * @param outputSize number of output values per sample
     * @param eta learning rate
     * @param momentum factor applied to the previous weight update
     * @throws IllegalArgumentException if any size is negative
     */
    public BP(int inputSize, int hiddenSize, int outputSize, double eta,
            double momentum) {
        // A size of -1 would silently create zero-length layers that only fail
        // later, inside forward(); reject it where the mistake is made.
        if (inputSize < 0 || hiddenSize < 0 || outputSize < 0) {
            throw new IllegalArgumentException(
                    "Layer sizes must not be negative: input=" + inputSize
                            + ", hidden=" + hiddenSize
                            + ", output=" + outputSize);
        }

        input = new double[inputSize + 1];
        hidden = new double[hiddenSize + 1];
        output = new double[outputSize + 1];
        target = new double[outputSize + 1];

        hidDelta = new double[hiddenSize + 1];
        optDelta = new double[outputSize + 1];

        iptHidWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptWeights = new double[hiddenSize + 1][outputSize + 1];

        // The order of these two calls determines which random numbers each
        // matrix receives, so it must not be changed.
        random = new Random(WEIGHT_SEED);
        randomizeWeights(iptHidWeights);
        randomizeWeights(hidOptWeights);

        iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1];
        hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1];

        this.eta = eta;
        this.momentum = momentum;
    }

    /**
     * Creates a network with the default eta = 0.25 and momentum = 0.9.
     *
     * @param inputSize number of input values per sample
     * @param hiddenSize number of hidden units
     * @param outputSize number of output values per sample
     * @throws IllegalArgumentException if any size is negative
     */
    public BP(int inputSize, int hiddenSize, int outputSize) {
        this(inputSize, hiddenSize, outputSize, DEFAULT_ETA, DEFAULT_MOMENTUM);
    }

    /**
     * Fills the matrix with values in (-1, 1): a magnitude from
     * {@code nextDouble()} and a sign decided by a second draw.
     *
     * @param matrix weight matrix to initialise in place
     */
    private void randomizeWeights(double[][] matrix) {
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; j++) {
                double real = random.nextDouble();
                row[j] = random.nextDouble() > 0.5 ? real : -real;
            }
        }
    }

    /**
     * Entry method: performs one training step (forward pass, delta
     * calculation, weight adjustment) on a single sample.
     *
     * @param trainData sample as a one-dimensional vector of length inputSize
     * @param target expected outputs, of length outputSize
     * @throws IllegalArgumentException if either length does not match
     */
    public void train(double[] trainData, double[] target) {
        // Validate both vectors before touching any state, so a rejected call
        // leaves the network exactly as it was.
        requireLength(trainData, input);
        requireLength(target, this.target);

        loadInput(trainData);
        loadTarget(target);
        forward();
        calculateDelta();
        adjustWeight();
    }

    /**
     * Runs a forward pass on the given sample without changing any weight.
     *
     * @param inData sample of length inputSize
     * @return a new array of length outputSize holding the output activations
     * @throws IllegalArgumentException if the length does not match
     */
    public double[] test(double[] inData) {
        loadInput(inData);
        forward();
        return getNetworkOutput();
    }

    /**
     * Copies the output layer without its unused slot 0.
     *
     * @return a new array holding {@code output[1..]}
     */
    private double[] getNetworkOutput() {
        double[] result = new double[output.length - 1];
        System.arraycopy(output, 1, result, 0, result.length);
        return result;
    }

    /**
     * Loads the target data into slots {@code 1..} of the target vector.
     *
     * @param arg expected outputs
     * @throws IllegalArgumentException if the length does not match
     */
    private void loadTarget(double[] arg) {
        requireLength(arg, target);
        System.arraycopy(arg, 0, target, 1, arg.length);
    }

    /**
     * Loads a sample into slots {@code 1..} of the input vector.
     *
     * @param inData sample values
     * @throws IllegalArgumentException if the length does not match
     */
    private void loadInput(double[] inData) {
        requireLength(inData, input);
        System.arraycopy(inData, 0, input, 1, inData.length);
    }

    /**
     * Checks that caller data fits a layer, which is one slot longer.
     *
     * @param data vector supplied by the caller
     * @param layer internal layer array the data is destined for
     * @throws IllegalArgumentException if {@code data.length != layer.length - 1}
     */
    private static void requireLength(double[] data, double[] layer) {
        if (data.length != layer.length - 1) {
            throw new IllegalArgumentException(SIZE_MISMATCH);
        }
    }

    /**
     * Propagates one layer into the next: each unit of {@code layer1} (from
     * index 1) becomes the sigmoid of the weighted sum over all of
     * {@code layer0}, including its threshold unit.
     *
     * @param layer0 source layer; its slot 0 is set to 1.0
     * @param layer1 destination layer; slot 0 is left untouched
     * @param weight weights indexed [source][destination]
     */
    private void forward(double[] layer0, double[] layer1, double[][] weight) {
        // threshold unit.
        layer0[0] = 1.0;
        for (int j = 1; j < layer1.length; j++) {
            double sum = 0;
            for (int i = 0; i < layer0.length; i++) {
                sum += weight[i][j] * layer0[i];
            }
            layer1[j] = sigmoid(sum);
        }
    }

    /**
     * Full forward pass: input to hidden, then hidden to output.
     */
    private void forward() {
        forward(input, hidden, iptHidWeights);
        forward(hidden, output, hidOptWeights);
    }

    /**
     * Calculates the output-layer deltas, {@code o * (1 - o) * (target - o)},
     * and stores the sum of their absolute values in {@link #optErrSum}.
     */
    private void outputErr() {
        double errSum = 0;
        for (int idx = 1; idx < optDelta.length; idx++) {
            double o = output[idx];
            optDelta[idx] = o * (1d - o) * (target[idx] - o);
            errSum += Math.abs(optDelta[idx]);
        }
        optErrSum = errSum;
    }

    /**
     * Calculates the hidden-layer deltas from the output-layer deltas and the
     * hidden-to-output weights, and stores the sum of their absolute values
     * in {@link #hidErrSum}.
     */
    private void hiddenErr() {
        double errSum = 0;
        for (int j = 1; j < hidDelta.length; j++) {
            double o = hidden[j];
            double sum = 0;
            for (int k = 1; k < optDelta.length; k++) {
                sum += hidOptWeights[j][k] * optDelta[k];
            }
            hidDelta[j] = o * (1d - o) * sum;
            errSum += Math.abs(hidDelta[j]);
        }
        hidErrSum = errSum;
    }

    /**
     * Calculates the deltas of all layers. The hidden deltas depend on the
     * output deltas, so the output layer is processed first.
     */
    private void calculateDelta() {
        outputErr();
        hiddenErr();
    }

    /**
     * Adjusts one weight matrix. Each update is
     * {@code momentum * previousUpdate + eta * delta[to] * layer[from]}; it is
     * added to the weight and remembered for the next call.
     *
     * @param delta deltas of the destination layer (slot 0 is skipped)
     * @param layer activations of the source layer; its slot 0 is set to 1
     * @param weight weights indexed [source][destination], updated in place
     * @param prevWeight previous updates, same layout, updated in place
     */
    private void adjustWeight(double[] delta, double[] layer,
            double[][] weight, double[][] prevWeight) {
        layer[0] = 1;
        for (int i = 1; i < delta.length; i++) {
            for (int j = 0; j < layer.length; j++) {
                double newVal = momentum * prevWeight[j][i]
                        + eta * delta[i] * layer[j];
                weight[j][i] += newVal;
                prevWeight[j][i] = newVal;
            }
        }
    }

    /**
     * Adjusts all weight matrices, hidden-to-output first.
     */
    private void adjustWeight() {
        adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights);
        adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights);
    }

    /**
     * Logistic sigmoid.
     *
     * @param val input value
     * @return {@code 1 / (1 + e^-val)}
     */
    private double sigmoid(double val) {
        return 1d / (1d + Math.exp(-val));
    }

}

测试代码:#################################start#################

import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Random; public class Test {     /**      * @param args      * @throws IOException      */     public static void main(String[] args) throws IOException {         BP bp = new BP(32, 15, 4);         Random random = new Random();         List list = new ArrayList();         for (int i = 0; i != 1000; i++) {             int value = random.nextInt();             list.add(value);         }         for (int i = 0; i != 200; i++) {             for (int value : list) {                 double[] real = new double[4];                 if (value >= 0)                     if ((value & 1) == 1)                         real[0] = 1;                     else                         real[1] = 1;                 else if ((value & 1) == 1)                     real[2] = 1;                 else                     real[3] = 1;                 double[] binary = new double[32];                 int index = 31;                 do {                     binary[index--] = (value & 1);                     value >>>= 1;                 } while (value != 0);                 bp.train(binary, real);             }         }         System.out.println("训练完毕,下面请输入一个任意数字,神经网络将自动判断它是正数还是复数,奇数还是偶数。");         while (true) {             byte[] input = new byte[10];             System.in.read(input);             Integer value = Integer.parseInt(new String(input).trim());             int rawVal = value;             double[] binary = new double[32];             int index = 31;             do {                 binary[index--] = (value & 1);                 value >>>= 1;             } while (value != 0);             double[] result = bp.test(binary);             double max = -Integer.MIN_VALUE;             int idx = -1;             for (int i = 0; i != result.length; i++) {                 if (result[i] > max) {                     max = result[i];                     idx = i;                 }             }           
  switch (idx) {             case 0:                 System.out.format("%d是一个正奇数\n", rawVal);                 break;             case 1:                 System.out.format("%d是一个正偶数\n", rawVal);                 break;             case 2:                 System.out.format("%d是一个负奇数\n", rawVal);                 break;             case 3:                 System.out.format("%d是一个负偶数\n", rawVal);                 break;             }         }     } }

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值