[机器学习]BP神经网络 java实现

代码转载处

https://www.cnblogs.com/hesi/p/7218602.html

对他文章的代码进行了修改优化

代码如下

package classfour;

public abstract class Student {
	// 输入层
	protected double[] input;
	// 输入层权重
	protected double[][] inputWeight;
	// 隐藏层
	protected double[] hide;
	// 隐藏层权重
	protected double[][] hideWeight;
	// 输出
	protected double[] out;
	// 实际值
	protected double[] target;
	// 学习率
	protected double rate;

	/**
	 * 初始化
	 * 
	 * @param inputNode
	 *            输入层的个数
	 * @param hideNode
	 *            隐藏层的个数
	 * @param outNode
	 *            输出层的个数
	 * @param rate
	 *            学习率
	 */
	public Student(int inputNode, int hideNode, int outNode, double rate) {
		input = new double[inputNode + 1];
		hide = new double[hideNode + 1];
		out = new double[outNode];
		inputWeight = new double[hideNode][inputNode + 1];
		hideWeight = new double[outNode][hideNode + 1];
		init(inputWeight);
		init(hideWeight);
		this.rate = rate;
	}

	/**
	 * 
	 * @param 初始化权重
	 */
	public void init(double[][] weight) {
		for (int i = 0; i < weight.length; i++) {
			for (int j = 0; j < weight[i].length; j++) {
				weight[i][j] = Math.random()>0.5?Math.random():-Math.random();
			}
		}
	}

	public abstract void study(double[] data, double[] result);

	public double[] predict(double[] data) {
		this.setInput(data);
		double[] output = new double[out.length + 1];
		forward(input, inputWeight, hide);
		forward(hide, hideWeight, output);
		for (int i = 0; i < out.length; i++) {
			out[i] = output[i + 1];
		}
		return out;
	}

	protected void setInput(double[] input) {
		this.input[0] = 1.0;
		for (int i = 0; i < input.length; i++) {
			this.input[i + 1] = input[i];
		}
	}

	/**
	 * 激励函数,这里为逻辑回归
	 * 
	 * @param z
	 * @return
	 */
	protected double Sigmoid(double z) {
		return 1d / (1d + Math.pow(Math.E, -z));
	}

	/**
	 * 代价函数 计算误差
	 * 
	 * @param data
	 * @param result
	 * @return
	 */
	public double costFunction(double[][] data, double[][] result) {
		double sum = 0d;
		for (int i = 0; i < result.length; i++) {
			sum += getSum(data[i], result[i]);
		}
		return -sum / result.length;
	}

	private double getSum(double[] data, double[] result) {
		double[] predict = predict(data);
		double sum = 0d;
		for (int i = 0; i < predict.length; i++) {
			sum += Math.log(predict[i]) * result[i] + (1 - result[i]) * Math.log(1 - predict[i]);
		}
		return sum;
	}

	public void forward(double[] x, double[][] weight, double[] out) {
		out[0] = 1d;
		for (int i = 0; i < weight.length; i++) {
			double sum = 0d;
			for (int j = 0; j < x.length; j++) {
				sum += x[j] * weight[i][j];
			}
			out[i + 1] = Sigmoid(sum);
		}
	}
}
package classfour;

public class WangJunLe extends Student {

	public WangJunLe(int inputNode, int hideNode, int outNode, double rate) {
		super(inputNode, hideNode, outNode, rate);
	}

	@Override
	public void study(double[] data, double[] result) {
		super.setInput(data);
		super.target = result;
		double[] output = new double[out.length + 1];
		// 前向传播
		forward(input, inputWeight, hide);
		forward(hide, hideWeight, output);
		// 后向传播
		backpropagation(output);
	}

	/**
	 * 
	 * @param x为输入层
	 * @param weight为权重
	 * @param out为下一层
	 */
	public void forward(double[] x, double[][] weight, double[] out) {
		out[0] = 1d;
		for (int i = 0; i < weight.length; i++) {
			double sum = 0d;
			for (int j = 0; j < x.length; j++) {
				sum += x[j] * weight[i][j];
			}
			out[i + 1] = Sigmoid(sum);
		}
	}

	/**
	 * 后向传播
	 * 
	 * @param out
	 *            为我们预测的输出结果
	 */
	private void backpropagation(double[] out) {
		double[] error = getError(out);
		double[] hideError = getHideError(error);
		updateWeight(hideError, inputWeight, input);
		updateWeight(error, hideWeight, hide);
	}

	private void updateWeight(double[] error, double[][] weight, double[] x) {
		for (int i = 0; i < weight.length; i++) {
			for (int j = 0; j < weight[i].length; j++) {
				weight[i][j] += rate * error[i] * x[j];
			}
		}
	}

	/**
	 * 误差为实际值减预测值
	 * @param out 为我们预测的输出结果
	 * @return 返回误差
	 */
	private double[] getError(double[] out) {
		double[] error = new double[out.length - 1];
		for (int i = 0; i < target.length; i++) {
			error[i] = target[i] - out[i + 1];
		}
		return error;
	}
	/**
	 * 获取隐藏层的错误
	 * 就是权重*上层误差之和
	 * @param error
	 * @return
	 */
	private double[] getHideError(double[] error) {
		double[] hideError = new double[hide.length - 1];
		for (int i = 0; i < hideError.length; i++) {
			double sum = 0;
			for (int j = 0; j < hideWeight.length; j++) {
				sum += hideWeight[j][i + 1] * error[j];
			}
			hideError[i] = sum * hide[i + 1] * (1d - hide[i + 1]);
		}
		return hideError;
	}

	@Override
	/**
	 * 获取数据集 前向传播得到结果进行返回
	 */
	public double[] predict(double[] data) {
		super.setInput(data);
		double[] output = new double[out.length + 1];
		forward(input, inputWeight, hide);
		forward(hide, hideWeight, output);
		for (int i = 0; i < out.length; i++) {
			out[i] = output[i + 1];
		}
		return out;
	}

}

package classfour;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import classfour.Student;
import classfour.WangJunLe;

public class Exam {

	/**
	 * @param args
	 * @throws IOException
	 */
	public static void main(String[] args) throws IOException {
		Student bp = new WangJunLe(32, 4, 4, 0.05);
		List<Integer> list = new ArrayList<Integer>();
		for (int i = 1; i <= 9; i++) {
			list.add(i);
			list.add(-i);
		}
		for (int i = 0; i != 100000; i++) {
			for (int value : list) {
				double[] real = new double[4];
				if (value >= 0)
					if ((value & 1) == 1)
						real[0] = 1;
					else
						real[1] = 1;
				else if ((value & 1) == 1)
					real[2] = 1;
				else
					real[3] = 1;

				double[] binary = new double[32];
				int index = 31;
				do {
					binary[index--] = (value & 1);
					value >>>= 1;
				} while (value != 0);
				bp.study(binary, real);
			}
		}

		System.out.println("请输入一个任意数字,将自动判断它是正数还是复数,奇数还是偶数。");

		while (true) {
			byte[] input = new byte[10];
			System.in.read(input);
			Integer value = Integer.parseInt(new String(input).trim());
			int rawVal = value;
			double[] binary = new double[32];
			int index = 31;
			do {
				binary[index--] = (value & 1);
				value >>>= 1;
			} while (value != 0);

			double[] result = bp.predict(binary);
			int idx = -1;
			for (int i = 0; i != result.length; i++) {
				if (result[i] > 0.5) {
					idx = i;
					break;
				}
			}
			for (double d : result) {
				System.out.println(d);
			}
			switch (idx) {
			case 0:
				System.out.format("%d是一个正奇数\n", rawVal);
				break;
			case 1:
				System.out.format("%d是一个正偶数\n", rawVal);
				break;
			case 2:
				System.out.format("%d是一个负奇数\n", rawVal);
				break;
			case 3:
				System.out.format("%d是一个负偶数\n", rawVal);
				break;
			}
		}
	}
}

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值