Java Learning Diary (Days 71-80: BP Neural Networks)

Learning resource

Day 71: Base class of the BP neural network (data loading and basic structure)

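Basic principle of the BP (back-propagation) algorithm:

The error at the output is used to estimate the error of the layer directly before the output layer. That estimate is then used to estimate the error of the layer before it, and so on. Propagating backward layer by layer in this way yields error estimates for all the other layers.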
A three-layer BP network:
[Figure: a three-layer BP network]
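Activation function: it must be differentiable everywhere (an S-shaped, i.e. sigmoid, function is usually used).
With a sigmoid activation, the input-output relationship of a BP neuron is:
Input:
$net = x_1w_1 + x_2w_2 + \dots + x_nw_n$
Output:
$y = f(net) = \frac{1}{1+e^{-net}}$
Derivative of the output:
$f'(net) = \frac{e^{-net}}{(1+e^{-net})^2} = \frac{1}{1+e^{-net}} - \frac{1}{(1+e^{-net})^2} = y(1-y)$
When training the network, we should try to keep net within the range where convergence is fast. Since y(1-y) is largest (0.25) at net = 0 and approaches 0 as |net| grows, net should stay near zero; otherwise the neuron saturates and its weights change very slowly.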
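As a quick numerical check of the derivative above, the standalone sketch below (the class name SigmoidDemo is hypothetical, not part of the diary's project) compares y(1-y) with a central finite difference. It also illustrates why net should be kept small: the derivative peaks at 0.25 when net = 0 and shrinks toward 0 as |net| grows.

```java
/**
 * Numerically checks that the sigmoid derivative equals y * (1 - y).
 */
public class SigmoidDemo {

    /** Sigmoid activation: y = 1 / (1 + e^{-net}). */
    public static double sigmoid(double net) {
        return 1.0 / (1.0 + Math.exp(-net));
    }

    /** The derivative expressed through the output y = sigmoid(net). */
    public static double derivativeFromOutput(double y) {
        return y * (1.0 - y);
    }

    /** Central finite difference, used only as a reference value. */
    public static double numericalDerivative(double net) {
        double h = 1e-5;
        return (sigmoid(net + h) - sigmoid(net - h)) / (2 * h);
    }

    public static void main(String[] args) {
        for (double net : new double[] { -6, -2, 0, 2, 6 }) {
            double y = sigmoid(net);
            System.out.printf("net=%5.1f  y=%.4f  y(1-y)=%.4f  numeric=%.4f%n",
                    net, y, derivativeFromOutput(y), numericalDerivative(net));
        }
    }
}
```

The analytic and numerical columns should agree to about four decimal places; the formula in terms of y is what the back-propagation code uses, because y is already stored after the forward pass.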

(Today's program was deliberately split out of the full implementation so that it can be reused.)

package xjx;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

public abstract class GeneralAnn {

	// The dataset.
	Instances dataset;

	// Number of layers, counted by node layers rather than edge layers.
	int numLayers;

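	// Number of nodes in each layer. For example, [3, 4, 6, 2] means 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6 nodes respectively, and 2 class values (binary classification).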
	int[] layerNumNodes;

	// Momentum coefficient.
	public double mobp;

	// Learning rate.
	public double learningRate;

	// Random number generator, used to initialize the weights.
	Random random = new Random();

	/**
	 ********************
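	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes in each layer (may differ between layers).
	 * @param paraLearningRate
	 *            The learning rate.
	 * @param paraMobp
	 *            The momentum coefficient.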
	 ********************
	 */
	public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
			double paraMobp) {
		// Step 1. Read the data.
		try {
			FileReader tempReader = new FileReader(paraFilename);
			dataset = new Instances(tempReader);
			// The last attribute is the class (decision) attribute.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			tempReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in GeneralAnn constructor.\r\n" + ee);
			System.exit(0);
		}

		// Step 2. Accept the parameters.
		layerNumNodes = paraLayerNumNodes;
		numLayers = layerNumNodes.length; // Number of node layers.
		// Adjust the input and output layer sizes to match the data.
		layerNumNodes[0] = dataset.numAttributes() - 1;
		layerNumNodes[numLayers - 1] = dataset.numClasses();
		learningRate = paraLearningRate;
		mobp = paraMobp;	
	}
	
	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public abstract double[] forward(double[] paraInput);

	/**
	 ********************
	 * Back-propagation.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 *            
	 ********************
	 */
	public abstract void backPropagation(double[] paraTarget);

	/**
	 ********************
	 * Train with the whole dataset (one pass, one instance at a time).
	 ********************
	 */
	public void train() {
		double[] tempInput = new double[dataset.numAttributes() - 1];
		double[] tempTarget = new double[dataset.numClasses()];
		for (int i = 0; i < dataset.numInstances(); i++) {
			// Fill in the input data.
			for (int j = 0; j < tempInput.length; j++) {
				tempInput[j] = dataset.instance(i).value(j);
			}

			// Fill in the one-hot class label.
			Arrays.fill(tempTarget, 0);
			tempTarget[(int) dataset.instance(i).classValue()] = 1;

			// Train on this instance: forward, then backward.
			forward(tempInput);
			backPropagation(tempTarget);
		}
	}

	/**
	 ********************
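	 * Get the index of the maximum value in an array.
	 * 
	 * @param paraArray
	 *            The array.
	 * @return The index (the first one in case of ties).
	 ********************
	 */
	public static int argmax(double[] paraArray) {
		int resultIndex = -1;
		double tempMax = Double.NEGATIVE_INFINITY;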
		for (int i = 0; i < paraArray.length; i++) {
			if (tempMax < paraArray[i]) {
				tempMax = paraArray[i];
				resultIndex = i;
			}
		}

		return resultIndex;
	}

	/**
	 ********************
	 * Test with the dataset.
	 * 
	 * @return The accuracy (fraction of correctly classified instances).
	 ********************
	 */
	public double test() {
		double[] tempInput = new double[dataset.numAttributes() - 1];

		double tempNumCorrect = 0;
		double[] tempPrediction;
		int tempPredictedClass = -1;

		for (int i = 0; i < dataset.numInstances(); i++) {
			// 填充数据
			for (int j = 0; j < tempInput.length; j++) {
				tempInput[j] = dataset.instance(i).value(j);
			}

			// Predict this instance (forward only).
			tempPrediction = forward(tempInput);
			//System.out.println("prediction: " + Arrays.toString(tempPrediction));
			tempPredictedClass = argmax(tempPrediction);
			if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
				tempNumCorrect++;
			}
		}

		System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

		return tempNumCorrect / dataset.numInstances();
	}
}
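train() encodes each class label as a one-hot target and test() decodes a prediction with argmax. The sketch below isolates those two steps plus the accuracy ratio that test() returns; OneHotDemo and its method names are hypothetical. One assumed design choice: argmax starts from Double.NEGATIVE_INFINITY rather than a fixed constant, so it works for any finite input.

```java
import java.util.Arrays;

/**
 * One-hot encoding (as in train()) and argmax decoding (as in test()).
 */
public class OneHotDemo {

    /** Builds a one-hot target vector: all zeros except a 1 at classIndex. */
    public static double[] oneHot(int classIndex, int numClasses) {
        double[] target = new double[numClasses]; // new arrays start at 0.0
        target[classIndex] = 1;
        return target;
    }

    /** Returns the index of the largest value; ties go to the first index. */
    public static int argmax(double[] values) {
        int resultIndex = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > max) {
                max = values[i];
                resultIndex = i;
            }
        }
        return resultIndex;
    }

    /** Fraction of rows whose argmax equals the true class label. */
    public static double accuracy(double[][] outputs, int[] labels) {
        int correct = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argmax(outputs[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / outputs.length;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot(2, 3))); // [0.0, 0.0, 1.0]
        double[][] outputs = { { 0.1, 0.7, 0.2 }, { 0.8, 0.1, 0.1 }, { 0.3, 0.3, 0.4 } };
        int[] labels = { 1, 0, 0 };
        System.out.println(accuracy(outputs, labels)); // 2 of 3 rows are correct
    }
}
```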

Day 72: BP neural network with a fixed activation function (1. Understanding the network structure)

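1. layerNumNodes describes the basic network structure. For example, [3, 4, 6, 2] means:
a) There are 3 input ports, i.e. the data has 3 conditional attributes. If this does not match the actual data, the code corrects it automatically (see the line layerNumNodes[0] = dataset.numAttributes() - 1; in the GeneralAnn constructor).
b) There are 2 output ports, i.e. the data has 2 decision classes. If this does not match the actual data, the code corrects it automatically (see the line layerNumNodes[numLayers - 1] = dataset.numClasses(); in the GeneralAnn constructor). For classification, the predicted class is the output port with the largest value.
c) There are two hidden layers, with 4 and 6 nodes respectively.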

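2. layerNodeValues holds the value of each network node. In the example above, the network has 4 node layers, so layerNodeValues.length is 4. The total number of nodes is $3 + 4 + 6 + 2 = 15$, i.e. layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6 and layerNodeValues[3].length = 2. Java supports such irregular (jagged) matrices, whose rows have different lengths, because a two-dimensional array is really a one-dimensional array of one-dimensional arrays.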

3. layerNodeErrors holds the error at each network node. It has the same shape as layerNodeValues.

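4. edgeWeights holds the weight of each edge. Because the edges between two adjacent layers form a many-to-many relation (a two-dimensional array), the edges of all layers form a three-dimensional array. For example, edge layer 0 in the example above has $(3+1) \times 4 = 16$ edges, where the $+1$ is the offset (bias). There are $4 - 1 = 3$ edge layers in total, i.e. one fewer than the number of node layers. This is a very common source of mistakes when writing the program.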

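5. edgeWeightsDelta has the same shape as edgeWeights. It stores the most recent change of each weight, which is reused through the momentum term (mobp) when the weights are adjusted.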
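To make these shapes concrete, the sketch below allocates the jagged arrays for layerNumNodes = {3, 4, 6, 2} the same way the SimpleAnn constructor does and counts nodes and weights. The class and helper names are hypothetical. For this example it reports 4 node layers, 15 nodes, 3 edge layers and 16 + 30 + 14 = 60 weights, offset rows included.

```java
/**
 * Allocates node and edge arrays for a layer specification and counts them.
 */
public class NetworkShapeDemo {

    /** One row per node layer; row l has layerNumNodes[l] entries (a jagged array). */
    public static double[][] nodeValues(int[] layerNumNodes) {
        double[][] values = new double[layerNumNodes.length][];
        for (int l = 0; l < layerNumNodes.length; l++) {
            values[l] = new double[layerNumNodes[l]];
        }
        return values;
    }

    /** One (layerNumNodes[l] + 1) x layerNumNodes[l + 1] matrix per edge layer. */
    public static double[][][] edgeWeights(int[] layerNumNodes) {
        double[][][] weights = new double[layerNumNodes.length - 1][][];
        for (int l = 0; l < weights.length; l++) {
            // The extra (+1) row holds the offset (bias) weights.
            weights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
        }
        return weights;
    }

    public static int countNodes(double[][] values) {
        int total = 0;
        for (double[] layer : values) {
            total += layer.length;
        }
        return total;
    }

    public static int countEdges(double[][][] weights) {
        int total = 0;
        for (double[][] layer : weights) {
            total += layer.length * layer[0].length;
        }
        return total;
    }

    public static void main(String[] args) {
        int[] layerNumNodes = { 3, 4, 6, 2 };
        double[][] values = nodeValues(layerNumNodes);
        double[][][] weights = edgeWeights(layerNumNodes);
        System.out.println("node layers = " + values.length);  // 4
        System.out.println("nodes       = " + countNodes(values)); // 15
        System.out.println("edge layers = " + weights.length); // 3
        System.out.println("edges       = " + countEdges(weights)); // 60
    }
}
```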

The core code is as follows.

package xjx;

public class SimpleAnn extends GeneralAnn{

	/**
	 * The value of each node, changed during forward propagation. The first dimension is the layer, the second is the node.
	 */
	public double[][] layerNodeValues;

	/**
	 * The error at each node, changed during back-propagation. The first dimension is the layer, the second is the node.
	 */
	public double[][] layerNodeErrors;

	/**
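	 * The edge weights. The first dimension is the layer, the second is the node index in that layer, and the third is the node index in the next layer.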
	 */
	public double[][][] edgeWeights;

	/**
	 * The changes of the edge weights (used for momentum). Same size as edgeWeights.
	 */
	public double[][][] edgeWeightsDelta;

	/**
	 ********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes in each layer (may differ between layers).
	 * @param paraLearningRate
	 *            The learning rate.
	 * @param paraMobp
	 *            The momentum coefficient.
	 ********************
	 */
	public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
		super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

		// Step 1. Allocate the outer (per-layer) arrays.
		layerNodeValues = new double[numLayers][];
		layerNodeErrors = new double[numLayers][];
		edgeWeights = new double[numLayers - 1][][];
		edgeWeightsDelta = new double[numLayers - 1][][];

		// Step 2. Allocate the inner arrays of each layer.
		for (int l = 0; l < numLayers; l++) {
			layerNodeValues[l] = new double[layerNumNodes[l]];
			layerNodeErrors[l] = new double[layerNumNodes[l]];

			// There is one fewer edge layer, because each edge connects two node layers.
			if (l + 1 == numLayers) {
				break;
			}

			// Of the layerNumNodes[l] + 1 rows, the last one is reserved for the offset (bias).
			edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
			edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
			for (int j = 0; j < layerNumNodes[l] + 1; j++) {
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					// Initialize the weights randomly in [0, 1).
					edgeWeights[l][j][i] = random.nextDouble();
				}
			}
		}
	}

	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		// Initialize the input layer.
		for (int i = 0; i < layerNodeValues[0].length; i++) {
			layerNodeValues[0][i] = paraInput[i];
		}

		// Compute the node values layer by layer.
		double z;
		for (int l = 1; l < numLayers; l++) {
			for (int j = 0; j < layerNodeValues[l].length; j++) {
				// Start from the offset (bias) weight; its input is always +1.
				z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
				// Add the weighted sum over all incoming edges of this node.
				for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
					z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
				}

				layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
			}
		}

		return layerNodeValues[numLayers - 1];
	}

	/**
	 ********************
	 * Back-propagate and change the edge weights.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 ********************
	 */
	public void backPropagation(double[] paraTarget) {
		// Step 1. Initialize the output layer errors.
		int l = numLayers - 1;
		for (int j = 0; j < layerNodeErrors[l].length; j++) {
			layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])
					* (paraTarget[j] - layerNodeValues[l][j]);
		}

		// Step 2. Propagate backward down to l = 0.
		while (l > 0) {
			l--;
			// Layer l.
			for (int j = 0; j < layerNumNodes[l]; j++) {
				double z = 0.0;
				// For each node of the next layer.
				for (int i = 0; i < layerNumNodes[l + 1]; i++) {
					if (l > 0) {
						z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
					}

					// Weight adjustment.
					edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
							+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
					edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
					if (j == layerNumNodes[l] - 1) {
						// Weight adjustment for the offset (bias) part.
						edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
								+ learningRate * layerNodeErrors[l + 1][i];
						edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
					}
				}

				// Record the error.
				layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
			} 
		}
	}

	/**
	 ********************
	 * Test the algorithm.
	 ********************
	 */
	public static void main(String[] args) {
		int[] tempLayerNodes = { 4, 8, 8, 3 };
		SimpleAnn tempNetwork = new SimpleAnn("D:/data/iris.arff", tempLayerNodes, 0.01, 0.6);

		for (int round = 0; round < 5000; round++) {
			tempNetwork.train();
		}

		double tempAccuracy = tempNetwork.test();
		System.out.println("The accuracy is: " + tempAccuracy);
	}
}

Run result:
[Screenshot of the run result]

Day 73: BP neural network with a fixed activation function (2. Understanding training and testing)

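1. forward is the process of using the current network to make a prediction for one instance.
2. backPropagation is the process of adjusting the network weights according to the error.
3. Training needs both the forward and the backward pass; testing needs only the forward pass.
4. Only the sigmoid activation function is implemented here, and the derivative used in back-propagation corresponds to the activation used in the forward pass. To change the activation function, both places must be changed together.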
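The forward/backward split and the pairing of sigmoid with y(1-y) can be seen in the smallest possible case: one sigmoid neuron with two inputs and an offset, trained on the OR function with the same momentum update rule as SimpleAnn.backPropagation(). This is a sketch: the class name, the OR data, the zero initial weights (chosen so the run is deterministic, whereas SimpleAnn uses random weights) and the hyperparameters 0.5 and 0.3 are all assumptions.

```java
import java.util.Arrays;

/**
 * A single sigmoid neuron trained on OR with SimpleAnn's update rule:
 *   error = y * (1 - y) * (target - y)
 *   delta = mobp * delta + learningRate * error * input
 */
public class SingleNeuronDemo {
    static final double[][] INPUTS = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
    static final double[] TARGETS = { 0, 1, 1, 1 };

    // weights[0], weights[1] belong to the inputs; weights[2] is the offset (bias).
    private final double[] weights = new double[3];
    private final double[] deltas = new double[3];
    private final double[] input = new double[2];
    private final double learningRate;
    private final double mobp;
    private double output;

    public SingleNeuronDemo(double learningRate, double mobp) {
        this.learningRate = learningRate;
        this.mobp = mobp;
    }

    /** Forward pass: weighted sum plus offset, then sigmoid. */
    public double forward(double[] x) {
        double net = weights[2];
        for (int i = 0; i < input.length; i++) {
            input[i] = x[i];
            net += weights[i] * input[i];
        }
        output = 1 / (1 + Math.exp(-net));
        return output;
    }

    /** Backward pass; must follow forward() on the same instance. */
    public void backPropagation(double target) {
        // y(1 - y) is the sigmoid derivative, matching the activation in forward().
        double error = output * (1 - output) * (target - output);
        for (int i = 0; i < input.length; i++) {
            deltas[i] = mobp * deltas[i] + learningRate * error * input[i];
            weights[i] += deltas[i];
        }
        // The offset's input is always +1.
        deltas[2] = mobp * deltas[2] + learningRate * error;
        weights[2] += deltas[2];
    }

    /** Training: forward then backward for every instance, repeated for the given rounds. */
    public static SingleNeuronDemo trainOr(int rounds) {
        SingleNeuronDemo neuron = new SingleNeuronDemo(0.5, 0.3);
        for (int round = 0; round < rounds; round++) {
            for (int k = 0; k < INPUTS.length; k++) {
                neuron.forward(INPUTS[k]);
                neuron.backPropagation(TARGETS[k]);
            }
        }
        return neuron;
    }

    public static void main(String[] args) {
        SingleNeuronDemo neuron = trainOr(5000);
        // Testing: forward only.
        for (double[] x : INPUTS) {
            System.out.printf("%s -> %.3f%n", Arrays.toString(x), neuron.forward(x));
        }
    }
}
```

After training, the output for (0, 0) should fall below 0.5 and the other three above 0.5. Replacing the sigmoid in forward() without also replacing y(1 - y) in backPropagation() would break this.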

/**
 ********************
 * Forward prediction.
 * 
 * @param paraInput
 *            The input data of one instance.
 * @return The data at the output end.
 ********************
 */
public double[] forward(double[] paraInput) {
	// Initialize the input layer.
	for (int i = 0; i < layerNodeValues[0].length; i++) {
		layerNodeValues[0][i] = paraInput[i];
	}

	// Compute the node values layer by layer.
	double z;
	for (int l = 1; l < numLayers; l++) {
		for (int j = 0; j < layerNodeValues[l].length; j++) {
			// Start from the offset (bias) weight; its input is always +1.
			z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
			// Add the weighted sum over all incoming edges of this node.
			for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
				z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
			}

			layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
		}
	}

	return layerNodeValues[numLayers - 1];
}

/**
 ********************
 * Back-propagate and change the edge weights.
 * 
 * @param paraTarget
 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
 ********************
 */
public void backPropagation(double[] paraTarget) {
	// Step 1. Initialize the output layer errors.
	int l = numLayers - 1;
	for (int j = 0; j < layerNodeErrors[l].length; j++) {
		layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])
				* (paraTarget[j] - layerNodeValues[l][j]);
	}

	// Step 2. Propagate backward down to l = 0.
	while (l > 0) {
		l--;
		// Layer l.
		for (int j = 0; j < layerNumNodes[l]; j++) {
			double z = 0.0;
			// For each node of the next layer.
			for (int i = 0; i < layerNumNodes[l + 1]; i++) {
				if (l > 0) {
					z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
				}

				// Weight adjustment.
				edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
						+ learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
				edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
				if (j == layerNumNodes[l] - 1) {
					// Weight adjustment for the offset (bias) part.
					edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
							+ learningRate * layerNodeErrors[l + 1][i];
					edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
				}
			}

			// Record the error.
			layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
		} 
	}
}

Day 74: General BP neural network (1. Managing activation functions in one place)

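Activation functions are at the core of a neural network.
1. Activation and differentiation form a pair: the former is used in forward, the latter in back-propagation.
2. There are many activation functions, and their design follows certain criteria, for example being piecewise differentiable.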
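Since activation and derivative always come as a pair, one alternative to the char-based switch in the Activator class below is to keep each pair together in an enum, so that the two cannot drift apart. This is a hypothetical design sketch, not the diary's code. It covers only four of the functions, and, mirroring Activator.derive(paraValue, paraActivatedValue), each derivative receives both x and f(x). A finite-difference helper makes it easy to check that each pair is consistent.

```java
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;

/**
 * Activation functions bundled with their derivatives.
 * The derivative is given (x, f(x)), like Activator.derive().
 */
public enum ActivationPair {
    SIGMOID(x -> 1 / (1 + Math.exp(-x)), (x, fx) -> fx * (1 - fx)),
    TANH(Math::tanh, (x, fx) -> 1 - fx * fx),
    RELU(x -> x >= 0 ? x : 0.0, (x, fx) -> x >= 0 ? 1.0 : 0.0),
    SOFT_SIGN(x -> x / (1 + Math.abs(x)), (x, fx) -> 1 / ((1 + Math.abs(x)) * (1 + Math.abs(x))));

    private final DoubleUnaryOperator function;
    private final DoubleBinaryOperator derivative;

    ActivationPair(DoubleUnaryOperator function, DoubleBinaryOperator derivative) {
        this.function = function;
        this.derivative = derivative;
    }

    /** Used in the forward pass. */
    public double activate(double x) {
        return function.applyAsDouble(x);
    }

    /** Used in back-propagation; fx must be activate(x). */
    public double derive(double x, double fx) {
        return derivative.applyAsDouble(x, fx);
    }

    /** Central difference, only for checking that the pair is consistent. */
    public double numericalDerivative(double x) {
        double h = 1e-6;
        return (activate(x + h) - activate(x - h)) / (2 * h);
    }

    public static void main(String[] args) {
        double x = 0.6; // the same test value as Activator.main()
        for (ActivationPair pair : values()) {
            double fx = pair.activate(x);
            System.out.printf("%-9s f(x) = %.6f, f'(x) = %.6f, numeric = %.6f%n",
                    pair, fx, pair.derive(x, fx), pair.numericalDerivative(x));
        }
    }
}
```

Adding a new activation then means adding one constant with both lambdas, rather than editing two separate switch statements.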

ReLU, also called the Rectified Linear Unit, is a piecewise linear function that alleviates the vanishing-gradient problem of the sigmoid and tanh functions. Its formula and graph are as follows:
$g(z) = \begin{cases} z, & \text{if } z > 0 \\ 0, & \text{if } z \le 0 \end{cases}$
[Figure: graph of the ReLU function]
The derivative of ReLU is:
$g'(z) = \begin{cases} 1, & \text{if } z > 0 \\ 0, & \text{if } z < 0 \end{cases}$
(The derivative is undefined at z = 0; the Activator code below uses 1 there.)
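Advantages of ReLU:
(1) For positive inputs (a large part of the input space of z), there is no vanishing-gradient problem.
(2) It is much faster to compute. ReLU involves only linear operations, so both forward and backward propagation are much faster than with sigmoid and tanh, which must compute exponentials.
Disadvantage of ReLU:
(1) When the input is negative, the gradient is 0, which causes a vanishing-gradient problem: a neuron stuck in this region stops updating its incoming weights (often called "dying ReLU").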
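Both points can be illustrated by multiplying the local derivative over several stacked layers. The sketch below deliberately ignores the weights (an assumption that isolates the activation's effect) and uses hypothetical names. With sigmoid, even the best case z = 0 shrinks the gradient by a factor of 0.25 per layer; with ReLU the factor is exactly 1 for positive inputs and exactly 0 for negative ones.

```java
/**
 * Product of local activation derivatives over `depth` layers that all see
 * the same pre-activation z. Weights are ignored on purpose.
 */
public class ReluGradientDemo {

    /** ReLU derivative; z = 0 is treated as positive, as in the Activator class. */
    public static double reluDerivative(double z) {
        return z >= 0 ? 1.0 : 0.0;
    }

    /** Sigmoid derivative y(1 - y); at most 0.25, reached at z = 0. */
    public static double sigmoidDerivative(double z) {
        double y = 1 / (1 + Math.exp(-z));
        return y * (1 - y);
    }

    public static double chainedReluGradient(double z, int depth) {
        double gradient = 1.0;
        for (int layer = 0; layer < depth; layer++) {
            gradient *= reluDerivative(z);
        }
        return gradient;
    }

    public static double chainedSigmoidGradient(double z, int depth) {
        double gradient = 1.0;
        for (int layer = 0; layer < depth; layer++) {
            gradient *= sigmoidDerivative(z);
        }
        return gradient;
    }

    public static void main(String[] args) {
        int depth = 10;
        System.out.println("sigmoid, z =  0: " + chainedSigmoidGradient(0.0, depth)); // 0.25^10, about 9.5e-7
        System.out.println("relu,    z =  2: " + chainedReluGradient(2.0, depth));    // 1.0
        System.out.println("relu,    z = -1: " + chainedReluGradient(-1.0, depth));   // 0.0
    }
}
```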

Code:

package xjx;

public class Activator {
	/**
	 * Arc tan.
	 */
	public final char ARC_TAN = 'a';

	/**
	 * Elu.
	 */
	public final char ELU = 'e';

	/**
	 * Gelu.
	 */
	public final char GELU = 'g';

	/**
	 * Hard logistic.
	 */
	public final char HARD_LOGISTIC = 'h';

	/**
	 * Identity.
	 */
	public final char IDENTITY = 'i';

	/**
	 * Leaky relu, also known as parametric relu.
	 */
	public final char LEAKY_RELU = 'l';

	/**
	 * Relu.
	 */
	public final char RELU = 'r';

	/**
	 * Soft sign.
	 */
	public final char SOFT_SIGN = 'o';

	/**
	 * Sigmoid.
	 */
	public final char SIGMOID = 's';

	/**
	 * Tanh.
	 */
	public final char TANH = 't';

	/**
	 * Soft plus.
	 */
	public final char SOFT_PLUS = 'u';

	/**
	 * Swish.
	 */
	public final char SWISH = 'w';

	/**
	 * The activator.
	 */
	private char activator;

	/**
	 * Alpha for elu and leaky relu.
	 */
	double alpha;

	/**
	 * Beta, currently unused (reserved for other activators).
	 */
	double beta;

	/**
	 * Gamma, currently unused (reserved for other activators).
	 */
	double gamma;

	/**
	 *********************
	 * The first constructor.
	 * 
	 * @param paraActivator
	 *            The activator.
	 *********************
	 */
	public Activator(char paraActivator) {
		activator = paraActivator;
	}

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	public void setActivator(char paraActivator) {
		activator = paraActivator;
	}

	/**
	 *********************
	 * Getter.
	 *********************
	 */
	public char getActivator() {
		return activator;
	}

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setAlpha(double paraAlpha) {
		alpha = paraAlpha;
	}

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setBeta(double paraBeta) {
		beta = paraBeta;
	}

	/**
	 *********************
	 * Setter.
	 *********************
	 */
	void setGamma(double paraGamma) {
		gamma = paraGamma;
	}

	/**
	 *********************
	 * Activate according to the activation function.
	 *********************
	 */
	public double activate(double paraValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = Math.atan(paraValue);
			break;
		case ELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1);
			}
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = paraValue;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * paraValue;
			}
			break;
		case SOFT_SIGN:
			if (paraValue >= 0) {
				resultValue = paraValue / (1 + paraValue);
			} else {
				resultValue = paraValue / (1 - paraValue);
			}
			break;
		case SOFT_PLUS:
			resultValue = Math.log(1 + Math.exp(paraValue));
			break;
		case RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = 0;
			}
			break;
		case SIGMOID:
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case TANH:
			resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			System.exit(0);
		}

		return resultValue;
	}

	/**
	 *********************
	 * Compute the derivative according to the activation function.
	 * 
	 * @param paraValue
	 *            The original value x.
	 * @param paraActivatedValue
	 *            f(x).
	 * @return f'(x).
	 *********************
	 */
	public double derive(double paraValue, double paraActivatedValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = 1 / (paraValue * paraValue + 1);
			break;
		case ELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
			} // Of if
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = 1;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha;
			}
			break;
		case SOFT_SIGN:
			if (paraValue >= 0) {
				resultValue = 1 / (1 + paraValue) / (1 + paraValue);
			} else {
				resultValue = 1 / (1 - paraValue) / (1 - paraValue);
			}
			break;
		case SOFT_PLUS:
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case RELU: // Updated
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = 0;
			}
			break;
		case SIGMOID: // Updated
			resultValue = paraActivatedValue * (1 - paraActivatedValue);
			break;
		case TANH: // Updated
			resultValue = 1 - paraActivatedValue * paraActivatedValue;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			System.exit(0);
		}

		return resultValue;
	}

	/**
	 *********************
	 * Overrides the method declared in Object.
	 *********************
	 */
	@Override
	public String toString() {
		String resultString = "Activator with function '" + activator + "'";
		resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;

		return resultString;
	}

	/**
	 ********************
	 * Test.
	 ********************
	 */
	public static void main(String[] args) {
		Activator tempActivator = new Activator('s');
		double tempValue = 0.6;
		double tempNewValue;
		tempNewValue = tempActivator.activate(tempValue);
		System.out.println("After activation: " + tempNewValue);

		tempNewValue = tempActivator.derive(tempValue, tempNewValue);
		System.out.println("After derive: " + tempNewValue);
	}
}

Run result:
[Screenshot of the run result]

Day 75: General BP neural network (2. Single-layer implementation)

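1. Only a single ANN layer is implemented.
2. The forward pass computes the output; the backward pass computes the errors and adjusts the weights.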
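The core of a single layer fits in a few lines. The sketch below is a simplified, hypothetical TinyLayer rather than the AnnLayer class that follows: it drops momentum and the Activator (sigmoid is hard-coded), takes fixed initial weights so the numbers are reproducible, and stores the offset (bias) as the last row of weights so that forward() and backPropagation() read and write the same row.

```java
import java.util.Arrays;

/**
 * A minimal single sigmoid layer. weights has (numInput + 1) rows and
 * numOutput columns; the last row is the offset (bias).
 */
public class TinyLayer {
    final int numInput;
    final int numOutput;
    final double[][] weights;
    final double[] input;
    final double[] activatedOutput;
    final double learningRate;

    public TinyLayer(double[][] initialWeights, double learningRate) {
        numInput = initialWeights.length - 1;
        numOutput = initialWeights[0].length;
        weights = new double[initialWeights.length][];
        for (int i = 0; i < initialWeights.length; i++) {
            weights[i] = initialWeights[i].clone();
        }
        input = new double[numInput];
        activatedOutput = new double[numOutput];
        this.learningRate = learningRate;
    }

    public double getWeight(int row, int column) {
        return weights[row][column];
    }

    /** Forward: offset plus weighted sum, then sigmoid. */
    public double[] forward(double[] paraInput) {
        System.arraycopy(paraInput, 0, input, 0, numInput);
        for (int j = 0; j < numOutput; j++) {
            double net = weights[numInput][j]; // offset, its input is +1
            for (int i = 0; i < numInput; i++) {
                net += input[i] * weights[i][j];
            }
            activatedOutput[j] = 1 / (1 + Math.exp(-net));
        }
        return activatedOutput.clone();
    }

    /** paraErrors: (target - output) for a last layer. Returns the errors of the inputs. */
    public double[] backPropagation(double[] paraErrors) {
        double[] delta = new double[numOutput];
        for (int j = 0; j < numOutput; j++) {
            delta[j] = activatedOutput[j] * (1 - activatedOutput[j]) * paraErrors[j];
        }
        double[] inputErrors = new double[numInput];
        for (int i = 0; i < numInput; i++) {
            for (int j = 0; j < numOutput; j++) {
                inputErrors[i] += delta[j] * weights[i][j]; // old weight, before the update
                weights[i][j] += learningRate * delta[j] * input[i];
            }
        }
        for (int j = 0; j < numOutput; j++) {
            weights[numInput][j] += learningRate * delta[j]; // offset row
        }
        return inputErrors;
    }

    public static void main(String[] args) {
        TinyLayer layer = new TinyLayer(new double[][] { { 0.5 }, { -0.25 }, { 0.1 } }, 0.5);
        double[] out = layer.forward(new double[] { 1, 4 });
        System.out.println("output = " + Arrays.toString(out));
        double[] errors = layer.backPropagation(new double[] { 1 - out[0] }); // target 1
        System.out.println("input errors = " + Arrays.toString(errors));
        System.out.println("offset after update = " + layer.getWeight(2, 0));
    }
}
```

With input {1, 4}, net = 0.1 + 0.5 × 1 − 0.25 × 4 = −0.4, so the output is sigmoid(−0.4). After one backward step toward target 1, the offset grows and the same input produces a larger output.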

Output:

Activator: Activator with function 's'
 alpha = 0.0, beta = 0.0, gamma = 0.0
 weights = [[0.6084959193944588, 0.4221456104753831, 0.6183449276687938], [0.7704253816634953, 0.6636288072285302, 0.8794802183018241], [0.17489521516629425, 0.004899930192123647, 0.13474601385167118]]
Forward, the output is: [0.9794693622124504, 0.9561257044911313, 0.9862247642459836]
Back propagation, the error is: [0.03720166959927697, 0.053575467276064444]

Code:

package xjx;

import java.util.Arrays;
import java.util.Random;

public class AnnLayer {

	/**
	 * Number of inputs.
	 */
	int numInput;

	/**
	 * Number of outputs.
	 */
	int numOutput;

	/**
	 * Learning rate.
	 */
	double learningRate;

	/**
	 * Momentum coefficient.
	 */
	double mobp;

	/**
	 * Weight matrix and its changes (for momentum). The last row holds the offset (bias).
	 */
	double[][] weights, deltaWeights;

	double[] offset, deltaOffset, errors;

	/**
	 * The input.
	 */
	double[] input;

	/**
	 * The output (weighted sums before activation).
	 */
	double[] output;

	/**
	 * The output after activation.
	 */
	double[] activatedOutput;

	/**
	 * The activator.
	 */
	Activator activator;
	Random random = new Random();

	/**
	 *********************
	 * The first constructor.
	 * 
	 * @param paraActivator
	 *            The activator.
	 *********************
	 */
	public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator, double paraLearningRate, double paraMobp) {
		numInput = paraNumInput;
		numOutput = paraNumOutput;
		learningRate = paraLearningRate;
		mobp = paraMobp;

		weights = new double[numInput + 1][numOutput];
		deltaWeights = new double[numInput + 1][numOutput];
		for (int i = 0; i < numInput + 1; i++) {
			for (int j = 0; j < numOutput; j++) {
				weights[i][j] = random.nextDouble();
			}
		}

		offset = new double[numOutput];
		deltaOffset = new double[numOutput];
		errors = new double[numInput];

		input = new double[numInput];
		output = new double[numOutput];
		activatedOutput = new double[numOutput];

		activator = new Activator(paraActivator);
	}

	/**
	 ********************
	 * Set the parameters of the activator.
	 * 
	 * @param paraAlpha
	 *            Alpha. Only effective for some activator types.
	 * @param paraBeta
	 *            Beta.
	 * @param paraGamma
	 *            Gamma.
	 ********************
	 */
	public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
		activator.setAlpha(paraAlpha);
		activator.setBeta(paraBeta);
		activator.setGamma(paraGamma);
	}

	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
		// Copy the data.
		for (int i = 0; i < numInput; i++) {
			input[i] = paraInput[i];
		}

		// Compute the weighted sum of each output, starting from the offset (bias) in the last row.
		for (int i = 0; i < numOutput; i++) {
			output[i] = weights[numInput][i];
			for (int j = 0; j < numInput; j++) {
				output[i] += input[j] * weights[j][i];
			}

			activatedOutput[i] = activator.activate(output[i]);
		}

		return activatedOutput;
	}

	/**
	 ********************
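	 * Back-propagate and change the edge weights.
	 * 
	 * @param paraErrors
	 *            The errors from the next layer (for the last layer, target - output).
	 * @return The errors of this layer's inputs, to be passed to the previous layer.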
	 ********************
	 */
	public double[] backPropagation(double[] paraErrors) {
		// Step 1. Multiply the incoming errors by the activation derivative.
		for (int i = 0; i < paraErrors.length; i++) {
			paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
		}
		// Step 2. Compute the errors of the inputs and adjust the weights.
		for (int i = 0; i < numInput; i++) {
			errors[i] = 0;
			for (int j = 0; j < numOutput; j++) {
				errors[i] += paraErrors[j] * weights[i][j];
				deltaWeights[i][j] = mobp * deltaWeights[i][j] + learningRate * paraErrors[j] * input[i];
				weights[i][j] += deltaWeights[i][j];

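				if (i == numInput - 1) {
					// Offset (bias) adjustment. forward() reads the bias from the last row
					// of weights, so that row is updated here (not a separate array).
					deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
					weights[numInput][j] += deltaWeights[numInput][j];
				}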
			}
		}

		return errors;
	}

	/**
	 ********************
	 * I am the last layer, set the errors.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 ********************
	 */
	public double[] getLastLayerErrors(double[] paraTarget) {
		double[] resultErrors = new double[numOutput];
		for (int i = 0; i < numOutput; i++) {
			resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
		}
		return resultErrors;
	}

	/**
	 ********************
	 * Show me.
	 ********************
	 */
	public String toString() {
		String resultString = "";
		resultString += "Activator: " + activator;
		resultString += "\r\n weights = " + Arrays.deepToString(weights);
		return resultString;
	}

	/**
	 ********************
	 * Unit test.
	 ********************
	 */
	public static void unitTest() {
		AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
		double[] tempInput = { 1, 4 };

		System.out.println(tempLayer);

		double[] tempOutput = tempLayer.forward(tempInput);
		System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

		double[] tempError = tempLayer.backPropagation(tempOutput);
		System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
	}

	/**
	 ********************
	 * Test the algorithm.
	 ********************
	 */
	public static void main(String[] args) {
		unitTest();
	}
}

Day 76: General BP neural network (3. Integration test)

Test result:

Correct: 146.0 out of 150
The accuracy is: 0.9733333333333334
FullAnn ends.

Code:

package xjx;

import java.util.Arrays;

public class FullAnn extends GeneralAnn {

	/**
	 * The layers.
	 */
	AnnLayer[] layers;

	/**
	 ********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 * @param paraLayerNumNodes
	 *            The number of nodes for each layer (may be different).
	 * @param paraLearningRate
	 *            Learning rate.
	 * @param paraMobp
	 *            Momentum coefficient.
	 * @param paraActivators
	 *            The activators of the layers, one character per layer (e.g. "sss").
	 ********************
	 */
	public FullAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp, String paraActivators) {
		super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

		// Initialize the layers.
		layers = new AnnLayer[numLayers - 1];
		for (int i = 0; i < layers.length; i++) {
			layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate, paraMobp);
		}
	}

	/**
	 ********************
	 * Forward prediction.
	 * 
	 * @param paraInput
	 *            The input data of one instance.
	 * @return The data at the output end.
	 ********************
	 */
	public double[] forward(double[] paraInput) {
		double[] resultArray = paraInput;
		for(int i = 0; i < numLayers - 1; i ++) {
			resultArray = layers[i].forward(resultArray);
		}
		
		return resultArray;
	}

	/**
	 ********************
	 * Back-propagation.
	 * 
	 * @param paraTarget
	 *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
	 *            
	 ********************
	 */
	public void backPropagation(double[] paraTarget) {
		double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
		for (int i = numLayers - 2; i >= 0; i--) {
			tempErrors = layers[i].backPropagation(tempErrors);
		}
		
		return;
	}

	/**
	 ********************
	 * Show me.
	 ********************
	 */
	public String toString() {
		String resultString = "I am a full ANN with " + numLayers + " layers";
		return resultString;
	}

	/**
	 ********************
	 * Test.
	 ********************
	 */
	public static void main(String[] args) {
		int[] tempLayerNodes = { 4, 8, 8, 3 };
		FullAnn tempNetwork = new FullAnn("D:/data/iris.arff", tempLayerNodes, 0.01, 0.6, "sss");

		for (int round = 0; round < 5000; round++) {
			tempNetwork.train();
		}

		double tempAccuracy = tempNetwork.test();
		System.out.println("The accuracy is: " + tempAccuracy);
		System.out.println("FullAnn ends.");
	}
}
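FullAnn.backPropagation() hands the error returned by layer i to layer i − 1. A standard way to confirm that such chained errors are right is a gradient check: each back-propagated term (error × input, which the diary's code multiplies by the learning rate and adds to the weight) should equal −∂L/∂w for L = ½(t − y)², estimated by finite differences. The sketch below does this for a hand-written 2-2-1 sigmoid network with arbitrary fixed weights; it is a hypothetical standalone check, not a test of FullAnn itself.

```java
/**
 * Gradient check for a 2-2-1 sigmoid network with loss L = 0.5 * (t - y)^2.
 */
public class GradientCheckDemo {
    // w1[i][j]: input i -> hidden j; row 2 holds the hidden offsets (bias).
    final double[][] w1 = { { 0.15, -0.30 }, { 0.20, 0.25 }, { 0.35, 0.35 } };
    // w2[j]: hidden j -> output; index 2 holds the output offset.
    final double[] w2 = { 0.40, -0.50, 0.60 };

    static double sigmoid(double z) {
        return 1 / (1 + Math.exp(-z));
    }

    double[] hidden(double[] x) {
        double[] h = new double[2];
        for (int j = 0; j < 2; j++) {
            double net = w1[2][j];
            for (int i = 0; i < 2; i++) {
                net += x[i] * w1[i][j];
            }
            h[j] = sigmoid(net);
        }
        return h;
    }

    double output(double[] x) {
        double[] h = hidden(x);
        double net = w2[2];
        for (int j = 0; j < 2; j++) {
            net += h[j] * w2[j];
        }
        return sigmoid(net);
    }

    double loss(double[] x, double t) {
        double y = output(x);
        return 0.5 * (t - y) * (t - y);
    }

    /** Back-propagated error of hidden node j times its input i (i == 2 is the offset input +1). */
    public double backpropTerm(double[] x, double t, int i, int j) {
        double[] h = hidden(x);
        double y = output(x);
        double outputError = y * (1 - y) * (t - y);                   // last-layer error
        double hiddenError = h[j] * (1 - h[j]) * outputError * w2[j]; // passed back one layer
        double in = (i == 2) ? 1.0 : x[i];
        return hiddenError * in;
    }

    /** -dL/dw1[i][j] estimated by central differences. */
    public double numericalTerm(double[] x, double t, int i, int j) {
        double eps = 1e-6;
        double saved = w1[i][j];
        w1[i][j] = saved + eps;
        double lossPlus = loss(x, t);
        w1[i][j] = saved - eps;
        double lossMinus = loss(x, t);
        w1[i][j] = saved; // restore the weight
        return -(lossPlus - lossMinus) / (2 * eps);
    }

    public static void main(String[] args) {
        GradientCheckDemo net = new GradientCheckDemo();
        double[] x = { 1.0, 0.5 };
        double t = 1.0;
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 2; j++) {
                System.out.printf("w1[%d][%d]: backprop = %.8f, numerical = %.8f%n", i, j,
                        net.backpropTerm(x, t, i, j), net.numericalTerm(x, t, i, j));
            }
        }
    }
}
```

If the two columns disagree, the usual suspects are an activation/derivative mismatch or using a weight after it has already been updated.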

Day 77: GUI (1. Dialog-related components)

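First, pick the GUI plugin matching your Eclipse version on the Eclipse website: installation link
Install it in Eclipse via Help -> Install New Software…
Restart Eclipse after the installation.
Then create a project with New → Project → WindowBuilder → SWT Designer → SWT/JFace Java Project, and create a package. When creating a class, choose New → Other, then WindowBuilder → Swing Designer → Application Window. Once the class is created, click Design to edit it visually.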
However, importing java.awt.event causes a compile error in a modular project; declaring requires java.desktop; in module-info.java fixes it.
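For reference, a minimal module-info.java might look like the fragment below. The module name xjx is only an assumption (use the module name of your own project), and a project that also uses other libraries, such as Weka, may need additional requires lines.

```java
// module-info.java, in the root of the source folder.
// "xjx" is an assumed module name.
module xjx {
    requires java.desktop; // the module that exports java.awt and javax.swing
}
```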
The window can then be designed visually in the Design view:
[Screenshot of the Design view]
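Code notes:
ApplicationShutdown.java is used only to exit the graphical user interface (GUI).
Only one static instance is created. The constructor is private, so the class cannot be instantiated with new from outside. This is a neat little trick (essentially the singleton pattern).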
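The same trick in isolation: a minimal sketch of an eagerly created singleton. The class AppRegistry and its window counter are hypothetical. Unlike ApplicationShutdown, the field is also declared final here, which prevents other code from replacing the single instance.

```java
/**
 * Eager singleton: one static instance plus a private constructor.
 */
public class AppRegistry {
    /** The only instance, created when the class is loaded. */
    public static final AppRegistry INSTANCE = new AppRegistry();

    private int openWindows;

    /** Private, so new AppRegistry() does not compile outside this class. */
    private AppRegistry() {
    }

    public void windowOpened() {
        openWindows++;
    }

    public int getOpenWindows() {
        return openWindows;
    }

    public static void main(String[] args) {
        AppRegistry a = AppRegistry.INSTANCE;
        AppRegistry b = AppRegistry.INSTANCE;
        a.windowOpened();
        System.out.println(a == b);             // true: the same object
        System.out.println(b.getOpenWindows()); // 1: state is shared
    }
}
```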

package xjx;

import java.awt.event.*;

public class ApplicationShutdown implements WindowListener, ActionListener {
	/**
	 * Only one instance.
	 */
	public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();

	/**
	 *************************** 
	 * This constructor is private such that the only instance is generated here.
	 *************************** 
	 */
	private ApplicationShutdown() {
	}// Of ApplicationShutdown.

	/**
	 *************************** 
	 * Shutdown the system
	 *************************** 
	 */
	public void windowClosing(WindowEvent comeInWindowEvent) {
		System.exit(0);
	}// Of windowClosing.

	public void windowActivated(WindowEvent comeInWindowEvent) {
	}// Of windowActivated.

	public void windowClosed(WindowEvent comeInWindowEvent) {
	}// Of windowClosed.

	public void windowDeactivated(WindowEvent comeInWindowEvent) {
	}// Of windowDeactivated.

	public void windowDeiconified(WindowEvent comeInWindowEvent) {
	}// Of windowDeiconified.

	public void windowIconified(WindowEvent comeInWindowEvent) {
	}// Of windowIconified.

	public void windowOpened(WindowEvent comeInWindowEvent) {
	}// Of windowOpened.

	/**
	 *************************** 
	 * Shut down the system when an action event (e.g. clicking an Exit button) occurs.
	 *************************** 
	 */
	public void actionPerformed(ActionEvent ee) {
		System.exit(0);
	}// Of actionPerformed.
}// Of class ApplicationShutdown

DialogCloser.java closes a single window (dialog) rather than the whole GUI.

package xjx;

import java.awt.*;
import java.awt.event.*;

public class DialogCloser extends WindowAdapter implements ActionListener {

	/**
	 * The dialog under control.
	 */
	private Dialog currentDialog;

	/**
	 *************************** 
	 * The first constructor.
	 *************************** 
	 */
	public DialogCloser() {
		super();
	}// Of the first constructor

	/**
	 *************************** 
	 * The second constructor.
	 * 
	 * @param paraDialog
	 *            the dialog under control
	 *************************** 
	 */
	public DialogCloser(Dialog paraDialog) {
		currentDialog = paraDialog;
	}// Of the second constructor

	/**
	 *************************** 
	 * Close the dialog when clicking the cross at the upper-right corner of the window.
	 * 
	 * @param paraWindowEvent
	 *            From it we can obtain which window sent the message because X
	 *            was used.
	 *************************** 
	 */
	public void windowClosing(WindowEvent paraWindowEvent) {
		paraWindowEvent.getWindow().dispose();
	}// Of windowClosing.

	/**
	 *************************** 
	 * Close the dialog while pushing an "OK" or "Cancel" button.
	 * 
	 * @param paraEvent
	 *            Not considered. 
	 *************************** 
	 */
	public void actionPerformed(ActionEvent paraEvent) {
		currentDialog.dispose();
	}// Of actionPerformed.
}// Of class DialogCloser

ErrorDialog.java displays error messages. With a GUI, we no longer have to rely on System.out.println.

package xjx;

import java.awt.*;

public class ErrorDialog extends Dialog {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 124535235L;

	/**
	 * The ONLY ErrorDialog.
	 */
	public static ErrorDialog errorDialog = new ErrorDialog();

	/**
	 * The label containing the message to display.
	 */
	private TextArea messageTextArea;

	/**
	 *************************** 
	 * Display an error dialog and respective error message. Like other dialogs,
	 * this constructor is private, such that users can use only one dialog,
	 * i.e., ErrorDialog.errorDialog to display message. This is helpful for
	 * saving space (only one dialog) since we may need "many" dialogs.
	 *************************** 
	 */
	private ErrorDialog() {
		// This dialog is modal.
		super(GUICommon.mainFrame, "Error", true);

		// Prepare for the dialog.
		messageTextArea = new TextArea();

		Button okButton = new Button("OK");
		okButton.setSize(20, 10);
		okButton.addActionListener(new DialogCloser(this));
		Panel okPanel = new Panel();
		okPanel.setLayout(new FlowLayout());
		okPanel.add(okButton);

		// Add TextArea and Button
		setLayout(new BorderLayout());
		add(BorderLayout.CENTER, messageTextArea);
		add(BorderLayout.SOUTH, okPanel);

		setLocation(200, 200);
		setSize(500, 200);
		addWindowListener(new DialogCloser());
		setVisible(false);
	}// Of constructor

	/**
	 *************************** 
	 * set message.
	 * 
	 * @param paramMessage
	 *            the new message
	 *************************** 
	 */
	public void setMessageAndShow(String paramMessage) {
		messageTextArea.setText(paramMessage);
		setVisible(true);
	}// Of setMessageAndShow
}// Of class ErrorDialog

GUICommon.java stores some shared (global) variables.
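setFrame() and setPane() enforce a "set only once" rule. The sketch below captures that rule in a small generic holder so it runs without a display; SetOnce is a hypothetical name. As design choices (assumptions, not the original behavior), it throws the unchecked IllegalStateException instead of a plain Exception, rejects null, and synchronizes access so two threads cannot both succeed.

```java
import java.util.Objects;

/**
 * Holds a value that may be set only once.
 */
public class SetOnce<T> {
    private T value;

    /** Stores the value; a second call, or a null argument, is rejected. */
    public synchronized void set(T newValue) {
        Objects.requireNonNull(newValue, "value");
        if (value != null) {
            throw new IllegalStateException("The value can be set only ONCE!");
        }
        value = newValue;
    }

    /** Returns the stored value, or null if set() has not been called yet. */
    public synchronized T get() {
        return value;
    }

    public static void main(String[] args) {
        SetOnce<String> mainFrameTitle = new SetOnce<>();
        mainFrameTitle.set("Main frame");
        try {
            mainFrameTitle.set("Another frame");
        } catch (IllegalStateException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
        System.out.println(mainFrameTitle.get()); // Main frame
    }
}
```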

package xjx;

import java.awt.*;
import javax.swing.*;

public class GUICommon extends Object {
	/**
	 * Only one main frame.
	 */
	public static Frame mainFrame = null;

	/**
	 * Only one main pane.
	 */
	public static JTabbedPane mainPane = null;

	/**
	 * For default project number.
	 */
	public static int currentProjectNumber = 0;

	/**
	 * Default font.
	 */
	public static final Font MY_FONT = new Font("Times New Roman", Font.PLAIN, 12);

	/**
	 * Default color
	 */
	public static final Color MY_COLOR = Color.lightGray;

	/**
	 *************************** 
	 * Set the main frame. This can be done only once, at the initializing stage.
	 * 
	 * @param paraFrame
	 *            the main frame of the GUI.
	 * @throws Exception
	 *             If the main frame is set more than once.
	 *************************** 
	 */
	public static void setFrame(Frame paraFrame) throws Exception {
		if (mainFrame == null) {
			mainFrame = paraFrame;
		} else {
			throw new Exception("The main frame can be set only ONCE!");
		} // Of if
	}// Of setFrame

	/**
	 *************************** 
	 * Set the main pane. This can be done only once, at the initializing stage.
	 * 
	 * @param paramPane
	 *            the main pane of the GUI.
	 * @throws Exception
	 *             If the main panel is set more than once.
	 *************************** 
	 */
	public static void setPane(JTabbedPane paramPane) throws Exception {
		if (mainPane == null) {
			mainPane = paramPane;
		} else {
			throw new Exception("The main panel can be set only ONCE!");
		} // Of if
	}// Of setPane

}// Of class GUICommon

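HelpDialog.java displays help information: clicking the Help button in the main window shows a description of the relevant parameters. The aim is to make the software easier to use and to maintain.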

package xjx;

import java.io.*;
import java.awt.*;
import java.awt.event.*;

public class HelpDialog extends Dialog implements ActionListener {
	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 3869415040299264995L;

	/**
	 *************************** 
	 * Display the help dialog.
	 * 
	 * @param paraTitle
	 *            the title of the dialog.
	 * @param paraFilename
	 *            the help file.
	 *************************** 
	 */
	public HelpDialog(String paraTitle, String paraFilename) {
		super(GUICommon.mainFrame, paraTitle, true);
		setBackground(GUICommon.MY_COLOR);

		TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
		displayArea.setEditable(false);
		String textToDisplay = "";
		try {
			RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
			String tempLine = helpFile.readLine();
			while (tempLine != null) {
				textToDisplay = textToDisplay + tempLine + "\n";
				tempLine = helpFile.readLine();
			}
			helpFile.close();
		} catch (IOException ee) {
			dispose();
			ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
		}
		// Use this if you need to display Chinese. Consult the author for this
		// method.
		// textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
		displayArea.setText(textToDisplay);
		displayArea.setFont(new Font("Times New Roman", Font.PLAIN, 14));

		Button okButton = new Button("OK");
		okButton.setSize(20, 10);
		okButton.addActionListener(new DialogCloser(this));
		Panel okPanel = new Panel();
		okPanel.setLayout(new FlowLayout());
		okPanel.add(okButton);

		// OK Button
		setLayout(new BorderLayout());
		add(BorderLayout.CENTER, displayArea);
		add(BorderLayout.SOUTH, okPanel);

		setLocation(120, 70);
		setSize(500, 400);
		addWindowListener(new DialogCloser());
		setVisible(false);
	}// Of constructor

	/**
	 ************************* 
	 * Simply set it visible.
	 ************************* 
	 */
	public void actionPerformed(ActionEvent ee) {
		setVisible(true);
	}// Of actionPerformed.
}// Of class HelpDialog
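The constructor above reads the help file with `RandomAccessFile.readLine`, which turns each byte into one `char` instead of decoding a character set; that is why the commented-out GB2312 conversion is needed before Chinese help text displays correctly. A minimal sketch of an alternative, assuming the help file is saved as UTF-8 (an assumption, not stated in the original); `HelpTextReader` and `readHelpText` are hypothetical names:

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public class HelpTextReader {
    /**
     * Hypothetical helper: read the whole help file, decoding it as UTF-8
     * (assumed encoding), and end every line with "\n" like HelpDialog does.
     */
    public static String readHelpText(String paraFilename) throws IOException {
        List<String> tempLines = Files.readAllLines(Paths.get(paraFilename), StandardCharsets.UTF_8);
        StringBuilder tempBuilder = new StringBuilder();
        for (String tempLine : tempLines) {
            tempBuilder.append(tempLine).append('\n');
        }
        return tempBuilder.toString();
    }

    public static void main(String[] args) throws IOException {
        // Write a small UTF-8 file containing Chinese text ("\u4e2d\u6587"), then read it back.
        Path tempFile = Files.createTempFile("help", ".txt");
        Files.write(tempFile, "\u4e2d\u6587 help\nmobp\n".getBytes(StandardCharsets.UTF_8));
        System.out.print(readHelpText(tempFile.toString()));
        Files.delete(tempFile);
    }
}
```

Inside HelpDialog, the `try` block could call such a helper and keep the existing `catch (IOException ee)` branch unchanged.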

Day 78: GUI (2. Data input controls)

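DoubleField.java accepts a real (double) value and reports an error if the text cannot be interpreted as one. This stops simple user mistakes before they ever reach the algorithm.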

package xjx_GUI;

import java.awt.*;
import java.awt.event.*;

public class DoubleField extends TextField implements FocusListener {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 363634723L;

	/**
	 * The value
	 */
	protected double doubleValue;

	/**
	 *************************** 
	 * Give it default values.
	 *************************** 
	 */
	public DoubleField() {
		this("5.13", 10);
	}// Of the first constructor

	/**
	 *************************** 
	 * Only specify the content.
	 * 
	 * @param paraString
	 *            The content of the field.
	 *************************** 
	 */
	public DoubleField(String paraString) {
		this(paraString, 10);
	}// Of the second constructor

	/**
	 *************************** 
	 * Only specify the width.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public DoubleField(int paraWidth) {
		this("5.13", paraWidth);
	}// Of the third constructor

	/**
	 *************************** 
	 * Specify the content and the width.
	 * 
	 * @param paraString
	 *            The content of the field.
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public DoubleField(String paraString, int paraWidth) {
		super(paraString, paraWidth);
		addFocusListener(this);
	}// Of the fourth constructor

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		try {
			doubleValue = Double.parseDouble(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog
					.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
			requestFocus();
		} // Of try
	}// Of focusLost

	/**
	 ********************************** 
	 * Get the double value.
	 * 
	 * @return the double value.
	 ********************************** 
	 */
	public double getValue() {
		try {
			doubleValue = Double.parseDouble(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog
					.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
			requestFocus();
		} // Of try
		return doubleValue;
	}// Of getValue
}// Of class DoubleField
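`Double.parseDouble` accepts more than plain decimals: exponent forms such as `1e-3`, and also the special strings `NaN` and `Infinity`. DoubleField therefore lets `NaN` through as, say, a learning rate. A minimal sketch of a stricter check, assuming every setting in this GUI must be a finite number (an assumption); `FiniteDoubleCheck` and `parseFiniteDouble` are hypothetical names:

```java
import java.util.OptionalDouble;

public class FiniteDoubleCheck {
    /**
     * Hypothetical helper: parse the text as a double and return empty for
     * null, unparsable, or non-finite (NaN, +/-Infinity) input.
     */
    public static OptionalDouble parseFiniteDouble(String paraText) {
        if (paraText == null) {
            return OptionalDouble.empty();
        }
        try {
            double tempValue = Double.parseDouble(paraText);
            return Double.isFinite(tempValue) ? OptionalDouble.of(tempValue) : OptionalDouble.empty();
        } catch (NumberFormatException ee) {
            return OptionalDouble.empty();
        }
    }

    public static void main(String[] args) {
        String[] tempInputs = { "0.01", "1e-3", "NaN", "Infinity", "abc" };
        for (String tempInput : tempInputs) {
            System.out.println(tempInput + " -> " + parseFiniteDouble(tempInput));
        }
    }
}
```

In `focusLost` and `getValue`, an empty result would then trigger the same `ErrorDialog` message as the existing `catch` branch.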

IntegeField.java works the same way.

package xjx_GUI;

import java.awt.*;
import java.awt.event.*;

public class IntegeField extends TextField implements FocusListener {

	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = -2462338973265150779L;

	/**
	 *************************** 
	 * Give it the default value "513".
	 *************************** 
	 */
	public IntegeField() {
		this("513");
	}// Of constructor

	/**
	 *************************** 
	 * Specify the content and the width.
	 * 
	 * @param paraString
	 *            The default value of the content.
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public IntegeField(String paraString, int paraWidth) {
		super(paraString, paraWidth);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * Only specify the content.
	 * 
	 * @param paraString
	 *            The given default string.
	 *************************** 
	 */
	public IntegeField(String paraString) {
		super(paraString);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * Only specify the width.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public IntegeField(int paraWidth) {
		super(paraWidth);
		setText("513");
		addFocusListener(this);
	}// Of constructor

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		try {
			Integer.parseInt(getText());
			// System.out.println(tempInt);
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
					+ "\" Not an integer. Please check.");
			requestFocus();
		}
	}// Of focusLost

	/**
	 ********************************** 
	 * Get the int value. Show error message if the content is not an int.
	 * 
	 * @return the int value.
	 ********************************** 
	 */
	public int getValue() {
		int tempInt = 0;
		try {
			tempInt = Integer.parseInt(getText());
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
					+ "\" Not an int. Please check.");
			requestFocus();
		}
		return tempInt;
	}// Of getValue

}// Of class IntegerField
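DoubleField and IntegeField look symmetric, but they treat surrounding blanks differently: `Double.parseDouble` ignores leading and trailing whitespace (as documented for `Double.valueOf`), whereas `Integer.parseInt` throws `NumberFormatException` for `" 5000 "`. A small sketch that demonstrates the difference and one possible fix, trimming first; `WhitespaceParsing` and `parseIntTrimmed` are hypothetical names:

```java
public class WhitespaceParsing {
    /** Returns true if Integer.parseInt accepts the text unchanged. */
    public static boolean parsesAsInt(String paraText) {
        try {
            Integer.parseInt(paraText);
            return true;
        } catch (NumberFormatException ee) {
            return false;
        }
    }

    /** Returns true if Double.parseDouble accepts the text unchanged. */
    public static boolean parsesAsDouble(String paraText) {
        try {
            Double.parseDouble(paraText);
            return true;
        } catch (NumberFormatException ee) {
            return false;
        }
    }

    /** Hypothetical helper: trim first, so " 5000 " is accepted as DoubleField would accept " 0.5 ". */
    public static int parseIntTrimmed(String paraText) {
        return Integer.parseInt(paraText.trim());
    }

    public static void main(String[] args) {
        String tempText = " 5000 ";
        System.out.println("parseInt accepts:    " + parsesAsInt(tempText));
        System.out.println("parseDouble accepts: " + parsesAsDouble(tempText));
        System.out.println("trimmed int value:   " + parseIntTrimmed(tempText));
    }
}
```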


FilenameField.java relies on the FileDialog provided by the system (AWT).

package xjx_GUI;

import java.io.*;
import java.awt.*;
import java.awt.event.*;

public class FilenameField extends TextField implements ActionListener,
		FocusListener {
	/**
	 * Serial uid. Not quite useful.
	 */
	private static final long serialVersionUID = 4572287941606065298L;

	/**
	 *************************** 
	 * No special initialization.
	 *************************** 
	 */
	public FilenameField() {
		super();
		setText("");
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 *************************** 
	 */
	public FilenameField(int paraWidth) {
		super(paraWidth);
		setText("");
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 * @param paraText
	 *            The given initial text
	 *************************** 
	 */
	public FilenameField(int paraWidth, String paraText) {
		super(paraWidth);
		setText(paraText);
		addFocusListener(this);
	}// Of constructor

	/**
	 *************************** 
	 * No special initialization.
	 * 
	 * @param paraWidth
	 *            The width of the field.
	 * @param paraText
	 *            The given initial text
	 *************************** 
	 */
	public FilenameField(String paraText, int paraWidth) {
		super(paraWidth);
		setText(paraText);
		addFocusListener(this);
	}// Of constructor

	/**
	 ********************************** 
	 * Avoid setting null or empty string.
	 * 
	 * @param paraText
	 *            The given text.
	 ********************************** 
	 */
	public void setText(String paraText) {
		if (paraText == null || paraText.trim().equals("")) {
			super.setText("unspecified");
		} else {
			super.setText(paraText.replace('\\', '/'));
		}//Of if
	}// Of setText

	/**
	 ********************************** 
	 * Implement ActionListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void actionPerformed(ActionEvent paraEvent) {
		FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
				"Select a file");
		tempDialog.setVisible(true);
		if (tempDialog.getDirectory() == null) {
			setText("");
			return;
		}//Of if
		
		String directoryName = tempDialog.getDirectory();
		
		String tempFilename = directoryName + tempDialog.getFile(); 
		//System.out.println("tempFilename = " + tempFilename);

		setText(tempFilename);
	}// Of actionPerformed

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusGained(FocusEvent paraEvent) {
	}// Of focusGained

	/**
	 ********************************** 
	 * Implement FocusListener.
	 * 
	 * @param paraEvent
	 *            The event is unimportant.
	 ********************************** 
	 */
	public void focusLost(FocusEvent paraEvent) {
		// System.out.println("Focus lost exists.");
		String tempString = getText();
		if ((tempString.equals("unspecified"))
				|| (tempString.equals("")))
			return;
		File tempFile = new File(tempString);
		if (!tempFile.exists()) {
			ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
					+ "\" does not exist. Please check.");
			requestFocus();
			setText("");
		}
	}// Of focusLost
}// Of class FilenameField
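`FilenameField.setText` applies two normalizations: a blank string is replaced by the placeholder `"unspecified"`, and Windows backslashes become forward slashes (which `java.io.File` also accepts on Windows). A minimal sketch that pulls this rule out into a pure function so it can be tested without opening a window; `FilenameNormalizer` and `normalizeFilename` are hypothetical names, and treating `null` as blank follows the Javadoc "Avoid setting null or empty string":

```java
public class FilenameNormalizer {
    /** Placeholder text FilenameField shows when nothing is selected. */
    public static final String UNSPECIFIED = "unspecified";

    /**
     * Hypothetical pure version of the FilenameField.setText rule:
     * null or blank input gives "unspecified"; otherwise every
     * backslash is replaced by a forward slash.
     */
    public static String normalizeFilename(String paraText) {
        if (paraText == null || paraText.trim().isEmpty()) {
            return UNSPECIFIED;
        }
        return paraText.replace('\\', '/');
    }

    public static void main(String[] args) {
        System.out.println(normalizeFilename("d:\\data\\iris.arff"));
        System.out.println(normalizeFilename("   "));
        System.out.println(normalizeFilename(null));
    }
}
```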


Day 79: GUI (3. Overall layout)

1. GridLayout and BorderLayout are used to organize the controls.
2. Pressing OK runs actionPerformed, which builds, trains and tests the network.
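The setting panel below is created with `GridLayout(3, 6)` and receives 8 label/field pairs, i.e. 16 components. According to the `GridLayout` documentation, when both rows and columns are non-zero the column argument is ignored and the column count is derived from the row count and the number of components, so here it is ceil(16 / 3) = 6; components fill the grid row by row, leaving the last two cells empty. A small arithmetic sketch of that rule (no AWT objects, so it also runs headless); `GridLayoutCells`, `effectiveColumns` and `cellOf` are hypothetical names, and left-to-right component orientation is assumed:

```java
public class GridLayoutCells {
    /** Effective column count when rows > 0: ceil(numComponents / rows). */
    public static int effectiveColumns(int paraRows, int paraNumComponents) {
        return (paraNumComponents + paraRows - 1) / paraRows;
    }

    /** Hypothetical helper: {row, column} of the component added at paraIndex (left-to-right). */
    public static int[] cellOf(int paraIndex, int paraRows, int paraNumComponents) {
        int tempColumns = effectiveColumns(paraRows, paraNumComponents);
        return new int[] { paraIndex / tempColumns, paraIndex % tempColumns };
    }

    public static void main(String[] args) {
        String[] tempLabels = { "alpha", "beta", "gamma", "layer nodes", "activators",
                "training rounds", "learning rate", "mobp" };
        int tempNumComponents = tempLabels.length * 2; // one Label plus one field per setting
        for (int i = 0; i < tempLabels.length; i++) {
            int[] tempCell = cellOf(2 * i, 3, tempNumComponents); // labels sit at even indices
            System.out.println(tempLabels[i] + " label: row " + tempCell[0] + ", column " + tempCell[1]);
        }
    }
}
```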

package xjx_GUI;

import java.awt.*;
import java.awt.event.*;
import java.util.Date;

import xjx.FullAnn;

public class AnnMain implements ActionListener {
	/**
	 * Select the arff file.
	 */
	private FilenameField arffFilenameField;

	/**
	 * The setting of alpha.
	 */
	private DoubleField alphaField;

	/**
	 * The setting of beta.
	 */
	private DoubleField betaField;

	/**
	 * The setting of gamma.
	 */
	private DoubleField gammaField;

	/**
	 * Layer nodes, such as "4, 8, 8, 3".
	 */
	private TextField layerNodesField;

	/**
	 * Activators, such as "ssa".
	 */
	private TextField activatorField;

	/**
	 * The number of training rounds.
	 */
	private IntegeField roundsField;

	/**
	 * The learning rate.
	 */
	private DoubleField learningRateField;

	/**
	 * The mobp.
	 */
	private DoubleField mobpField;

	/**
	 * The message area.
	 */
	private TextArea messageTextArea;

	/**
	 *************************** 
	 * The only constructor.
	 *************************** 
	 */
	public AnnMain() {
		// A simple frame to contain dialogs.
		Frame mainFrame = new Frame();
		mainFrame.setTitle("ANN");
		// The top part: select arff file.
		arffFilenameField = new FilenameField(30);
		arffFilenameField.setText("d:/data/iris.arff");
		Button browseButton = new Button(" Browse ");
		// Clicking Browse calls arffFilenameField.actionPerformed, which opens a FileDialog.
		browseButton.addActionListener(arffFilenameField);

		Panel sourceFilePanel = new Panel();
		sourceFilePanel.add(new Label("The .arff file:"));
		sourceFilePanel.add(arffFilenameField);
		sourceFilePanel.add(browseButton);

		// Setting panel.
		Panel settingPanel = new Panel();
		settingPanel.setLayout(new GridLayout(3, 6));

		settingPanel.add(new Label("alpha"));
		alphaField = new DoubleField("0.01");
		settingPanel.add(alphaField);

		settingPanel.add(new Label("beta"));
		betaField = new DoubleField("0.02");
		settingPanel.add(betaField);

		settingPanel.add(new Label("gamma"));
		gammaField = new DoubleField("0.03");
		settingPanel.add(gammaField);

		settingPanel.add(new Label("layer nodes"));
		layerNodesField = new TextField("4, 8, 8, 3");
		settingPanel.add(layerNodesField);

		settingPanel.add(new Label("activators"));
		activatorField = new TextField("sss");
		settingPanel.add(activatorField);

		settingPanel.add(new Label("training rounds"));
		roundsField = new IntegeField("5000");
		settingPanel.add(roundsField);

		settingPanel.add(new Label("learning rate"));
		learningRateField = new DoubleField("0.01");
		settingPanel.add(learningRateField);

		settingPanel.add(new Label("mobp"));
		mobpField = new DoubleField("0.5");
		settingPanel.add(mobpField);

		Panel topPanel = new Panel();
		topPanel.setLayout(new BorderLayout());
		topPanel.add(BorderLayout.NORTH, sourceFilePanel);
		topPanel.add(BorderLayout.CENTER, settingPanel);

		messageTextArea = new TextArea(50, 40);

		// The bottom part: ok and exit
		Button okButton = new Button(" OK ");
		okButton.addActionListener(this);
		// DialogCloser dialogCloser = new DialogCloser(this);
		Button exitButton = new Button(" Exit ");
		// cancelButton.addActionListener(dialogCloser);
		exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
		Button helpButton = new Button(" Help ");
		helpButton.setSize(20, 10);
		helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
		Panel okPanel = new Panel();
		okPanel.add(okButton);
		okPanel.add(exitButton);
		okPanel.add(helpButton);

		mainFrame.setLayout(new BorderLayout());
		mainFrame.add(BorderLayout.NORTH, topPanel);
		mainFrame.add(BorderLayout.CENTER, messageTextArea);
		mainFrame.add(BorderLayout.SOUTH, okPanel);

		mainFrame.setSize(600, 500);
		mainFrame.setLocation(100, 100);
		mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
		mainFrame.setBackground(GUICommon.MY_COLOR);
		mainFrame.setVisible(true);
	}// Of the constructor

	/**
	 *************************** 
	 * Read the settings, build the network from the arff file, train it and report the accuracy.
	 *************************** 
	 */
	public void actionPerformed(ActionEvent ae) {
		String tempFilename = arffFilenameField.getText();

		// Read the layers nodes.
		String tempString = layerNodesField.getText().trim();

		int[] tempLayerNodes = null;
		try {
			tempLayerNodes = stringToIntArray(tempString);
		} catch (Exception ee) {
			ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
			return;
		} // Of try

		double tempLearningRate = learningRateField.getValue();
		double tempMobp = mobpField.getValue();
		String tempActivators = activatorField.getText().trim();
		FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
				tempActivators);
		int tempRounds = roundsField.getValue();

		long tempStartTime = new Date().getTime();
		for (int i = 0; i < tempRounds; i++) {
			tempNetwork.train();
		} // Of for i
		long tempEndTime = new Date().getTime();
		messageTextArea.append("\r\nSummary:\r\n");
		messageTextArea.append("Training time: " + (tempEndTime - tempStartTime) + "ms.\r\n");

		double tempAccuracy = tempNetwork.test();
		messageTextArea.append("Accuracy: " + tempAccuracy + "\r\n");
		messageTextArea.append("End.");
	}// Of actionPerformed

	/**
	 ********************************** 
	 * Convert a string with commas into an int array.
	 * 
	 * @param paraString
	 *            The source string
	 * @return An int array.
	 * @throws Exception
	 *             Exception for illegal data.
	 ********************************** 
	 */
	public static int[] stringToIntArray(String paraString) throws Exception {
		int tempCounter = 1;
		for (int i = 0; i < paraString.length(); i++) {
			if (paraString.charAt(i) == ',') {
				tempCounter++;
			} // Of if
		} // Of for i

		int[] resultArray = new int[tempCounter];

		String tempRemainingString = new String(paraString) + ",";
		String tempString;
		for (int i = 0; i < tempCounter; i++) {
			tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
			if (tempString.equals("")) {
				throw new Exception("Blank is unsupported");
			} // Of if

			resultArray[i] = Integer.parseInt(tempString);

			tempRemainingString = tempRemainingString
					.substring(tempRemainingString.indexOf(",") + 1);
		} // Of for i

		return resultArray;
	}// Of stringToIntArray

	/**
	 *************************** 
	 * The entrance method.
	 * 
	 * @param args
	 *            The parameters.
	 *************************** 
	 */
	public static void main(String args[]) {
		new AnnMain();
	}// Of main
}// Of class AnnMain
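`stringToIntArray` first counts the commas and then repeatedly cuts substrings. A shorter sketch with the same contract, assuming `String.split` is acceptable here: blank entries are rejected with the same message, and non-numeric entries raise `NumberFormatException` from `Integer.parseInt`. The limit `-1` keeps trailing empty pieces, so `"4, 8,"` is rejected just as in the original. `LayerNodesParser` and `stringToIntArraySplit` are hypothetical names:

```java
import java.util.Arrays;

public class LayerNodesParser {
    /**
     * Hypothetical split-based variant of AnnMain.stringToIntArray.
     * split(",", -1) keeps trailing empty strings, so "4, 8," fails on the blank piece.
     */
    public static int[] stringToIntArraySplit(String paraString) throws Exception {
        String[] tempPieces = paraString.split(",", -1);
        int[] resultArray = new int[tempPieces.length];
        for (int i = 0; i < tempPieces.length; i++) {
            String tempPiece = tempPieces[i].trim();
            if (tempPiece.isEmpty()) {
                throw new Exception("Blank is unsupported");
            }
            resultArray[i] = Integer.parseInt(tempPiece); // NumberFormatException for non-numbers
        }
        return resultArray;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(Arrays.toString(stringToIntArraySplit("4, 8, 8, 3")));
    }
}
```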

Day 80: GUI (4. Event-listening mechanisms)

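1. From the perspective of the listener mechanism and interfaces, analyze which code each GUI operation triggers;
2. Summarize the basic artificial neural network.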

1. The Java event-listening mechanism

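In a GUI program like the one above, menu bars, menu items, buttons and so on are all objects, and clicking one of them should carry out some task. Likewise, mouse operations such as a single click, a double click, or the mouse entering or leaving a component should be able to trigger tasks. Java handles this with its event-listening mechanism: when an event occurs (a button is clicked, the mouse moves, a window is closed, and so on), the source object creates an event object and passes it to the listener objects registered on it, which then handle it.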
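This is the observer pattern: the source keeps a list of listeners added through an `addXxxListener` method, and when the event happens it builds one event object and calls each listener in turn. A minimal headless sketch of that flow; `ClickEvent`, `ClickListener` and `SimpleButton` are hypothetical stand-ins for `ActionEvent`, `ActionListener` and `java.awt.Button`, built on the real `java.util.EventObject` and `java.util.EventListener` types:

```java
import java.util.ArrayList;
import java.util.EventListener;
import java.util.EventObject;
import java.util.List;

public class ListenerDemo {
    /** Hypothetical event type; AWT's ActionEvent also descends from EventObject. */
    public static class ClickEvent extends EventObject {
        private final String command;

        public ClickEvent(Object paraSource, String paraCommand) {
            super(paraSource);
            command = paraCommand;
        }

        public String getCommand() {
            return command;
        }
    }

    /** Hypothetical listener interface, playing the role of ActionListener. */
    public interface ClickListener extends EventListener {
        void clicked(ClickEvent paraEvent);
    }

    /** Hypothetical event source, playing the role of java.awt.Button. */
    public static class SimpleButton {
        private final String label;
        private final List<ClickListener> listeners = new ArrayList<>();

        public SimpleButton(String paraLabel) {
            label = paraLabel;
        }

        /** Registration, like Button.addActionListener. */
        public void addClickListener(ClickListener paraListener) {
            listeners.add(paraListener);
        }

        /** Simulates a user click: one event object, every listener notified in registration order. */
        public void click() {
            ClickEvent tempEvent = new ClickEvent(this, label);
            for (ClickListener tempListener : listeners) {
                tempListener.clicked(tempEvent);
            }
        }
    }

    public static void main(String[] args) {
        SimpleButton tempOk = new SimpleButton("OK");
        List<String> tempLog = new ArrayList<>();
        tempOk.addClickListener(e -> tempLog.add("first listener: " + e.getCommand()));
        tempOk.addClickListener(e -> tempLog.add("second listener: " + e.getCommand()));
        tempOk.click();
        System.out.println(tempLog);
    }
}
```

In the program above the same registrations decide what runs: clicking OK calls `AnnMain.actionPerformed` (registered with `okButton.addActionListener(this)`); clicking Browse calls `FilenameField.actionPerformed`, which opens a FileDialog; clicking Help calls `HelpDialog.actionPerformed`, which makes the dialog visible; clicking Exit goes to `ApplicationShutdown.applicationShutdown`; leaving a DoubleField, IntegeField or FilenameField calls its `focusLost`, which validates the text; and closing the main window is handled by `ApplicationShutdown.applicationShutdown`, registered as a WindowListener.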

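| User action | Source object | Event type triggered |
| --- | --- | --- |
| Click a button | JButton (AWT: Button) | ActionEvent |
| Press Enter in a text field | JTextField (AWT: TextField) | ActionEvent |
| Open, minimize or close a window | Window | WindowEvent |
| Click, double-click or move the mouse | Component | MouseEvent |
| Click a radio button | JRadioButton | ItemEvent, ActionEvent |
| Click a check box | JCheckBox | ItemEvent, ActionEvent |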

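In Java, every event is represented by an object of a subclass of java.util.EventObject, for example:
MouseEvent: a mouse action, such as moving the cursor, clicking or double-clicking;
KeyEvent: a key pressed on the keyboard;
ActionEvent: a user-interface action, such as clicking a button on the screen.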
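The statement above can be checked with reflection. A small sketch that lists the superclass chain of several event types used by this program's listeners; it only inspects `Class` objects, so no window is opened. `EventHierarchy` and `superclassChain` are hypothetical names:

```java
import java.awt.event.ActionEvent;
import java.awt.event.FocusEvent;
import java.awt.event.ItemEvent;
import java.awt.event.KeyEvent;
import java.awt.event.MouseEvent;
import java.awt.event.WindowEvent;
import java.util.ArrayList;
import java.util.List;

public class EventHierarchy {
    /** The class itself followed by its superclasses, stopping before Object. */
    public static List<Class<?>> superclassChain(Class<?> paraType) {
        List<Class<?>> tempChain = new ArrayList<>();
        for (Class<?> c = paraType; c != null && c != Object.class; c = c.getSuperclass()) {
            tempChain.add(c);
        }
        return tempChain;
    }

    public static void main(String[] args) {
        Class<?>[] tempEventTypes = { ActionEvent.class, MouseEvent.class, KeyEvent.class,
                ItemEvent.class, WindowEvent.class, FocusEvent.class };
        for (Class<?> tempType : tempEventTypes) {
            StringBuilder tempLine = new StringBuilder();
            for (Class<?> c : superclassChain(tempType)) {
                if (tempLine.length() > 0) {
                    tempLine.append(" -> ");
                }
                tempLine.append(c.getName());
            }
            System.out.println(tempLine);
        }
    }
}
```

Every chain ends at `java.util.EventObject`, passing through `java.awt.AWTEvent`; mouse and key events additionally go through `InputEvent` and `ComponentEvent`.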

2. Summary of the basic artificial neural network

BP neural network:

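An iterative algorithm: initialize the weights randomly, compute the current output of the network, use the difference between the current output and the label to adjust the parameters of the preceding layers, and repeat until convergence.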
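A minimal sketch of that loop for a network with one hidden layer: sigmoid activations, squared error, online (per-sample) updates, and the momentum coefficient that GeneralAnn calls `mobp`. It is not the FullAnn implementation from the earlier days; the 2-3-1 architecture, the OR toy data, the seed, the learning rate and the number of rounds are all arbitrary assumptions chosen for illustration, and `TinyBp` is a hypothetical name:

```java
import java.util.Random;

public class TinyBp {
    static final int IN = 2, HID = 3;               // assumed sizes: 2 inputs, 3 hidden nodes, 1 output
    final double[][] w1 = new double[IN + 1][HID];  // input->hidden weights; row IN holds the biases
    final double[] w2 = new double[HID + 1];        // hidden->output weights; entry HID is the bias
    final double[][] dw1 = new double[IN + 1][HID]; // previous updates, kept for the momentum term
    final double[] dw2 = new double[HID + 1];
    final double[] hidden = new double[HID];        // hidden outputs of the latest forward pass

    static double sigmoid(double paraNet) {
        return 1.0 / (1.0 + Math.exp(-paraNet));
    }

    TinyBp(long paraSeed) {
        Random tempRandom = new Random(paraSeed);
        for (double[] tempRow : w1) {
            for (int j = 0; j < HID; j++) {
                tempRow[j] = tempRandom.nextDouble() - 0.5; // small random initial values
            }
        }
        for (int j = 0; j <= HID; j++) {
            w2[j] = tempRandom.nextDouble() - 0.5;
        }
    }

    /** Forward pass: fills hidden[] and returns the network output y. */
    double forward(double[] paraX) {
        double tempOutNet = w2[HID];
        for (int j = 0; j < HID; j++) {
            double tempNet = w1[IN][j];
            for (int i = 0; i < IN; i++) {
                tempNet += paraX[i] * w1[i][j];
            }
            hidden[j] = sigmoid(tempNet);
            tempOutNet += hidden[j] * w2[j];
        }
        return sigmoid(tempOutNet);
    }

    /** One online BP step with momentum: dw = mobp * dwOld + rate * delta * input. */
    void train(double[] paraX, double paraTarget, double paraRate, double paraMobp) {
        double y = forward(paraX);
        // Output error term (t - y) * f'(net), using f'(net) = y(1 - y).
        double tempDeltaOut = (paraTarget - y) * y * (1 - y);
        // Hidden error terms, propagated back through the not-yet-updated output weights.
        double[] tempDeltaHid = new double[HID];
        for (int j = 0; j < HID; j++) {
            tempDeltaHid[j] = tempDeltaOut * w2[j] * hidden[j] * (1 - hidden[j]);
        }
        for (int j = 0; j <= HID; j++) {
            double tempIn = (j < HID) ? hidden[j] : 1.0; // the bias input is 1
            dw2[j] = paraMobp * dw2[j] + paraRate * tempDeltaOut * tempIn;
            w2[j] += dw2[j];
        }
        for (int i = 0; i <= IN; i++) {
            double tempIn = (i < IN) ? paraX[i] : 1.0;
            for (int j = 0; j < HID; j++) {
                dw1[i][j] = paraMobp * dw1[i][j] + paraRate * tempDeltaHid[j] * tempIn;
                w1[i][j] += dw1[i][j];
            }
        }
    }

    /** Mean squared error over a data set. */
    double loss(double[][] paraX, double[] paraT) {
        double tempSum = 0;
        for (int k = 0; k < paraX.length; k++) {
            double tempDiff = paraT[k] - forward(paraX[k]);
            tempSum += tempDiff * tempDiff;
        }
        return tempSum / paraX.length;
    }

    public static void main(String[] args) {
        double[][] tempX = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } }; // logical OR (toy data)
        double[] tempT = { 0, 1, 1, 1 };
        TinyBp tempNetwork = new TinyBp(42);
        System.out.println("MSE before training: " + tempNetwork.loss(tempX, tempT));
        for (int tempRound = 0; tempRound < 5000; tempRound++) {
            for (int k = 0; k < tempX.length; k++) {
                tempNetwork.train(tempX[k], tempT[k], 0.3, 0.5);
            }
        }
        System.out.println("MSE after training:  " + tempNetwork.loss(tempX, tempT));
    }
}
```

Computing all error terms before changing any weight keeps the backward pass consistent with the outputs of the forward pass it follows.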

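Drawbacks:
The gradient becomes increasingly sparse: the further down from the top layer, the smaller the error-correction signal becomes;
It can converge to a local optimum, especially when training starts far from the optimal region (a consequence of random initialization);
It can generally be trained only on labeled data, but most data is unlabeled;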
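The first drawback can be made concrete with the sigmoid derivative from day 71. With $y \in (0,1)$, the derivative is at most $1/4$, and the backward pass multiplies each layer's error term by that derivative and by the weights. Writing $\delta^{(l)}_j$ for the error term of node $j$ in layer $l$ and $w^{(l)}_{jk}$ for the weight from node $j$ in layer $l$ to node $k$ in layer $l+1$ (notation assumed here for the sketch):

$$
f'(net) = y(1-y) = \frac{1}{4} - \left(y - \frac{1}{2}\right)^2 \le \frac{1}{4}
$$

$$
\delta^{(l)}_j = f'\big(net^{(l)}_j\big) \sum_k w^{(l)}_{jk}\, \delta^{(l+1)}_k
\;\Longrightarrow\;
\max_j \big|\delta^{(l)}_j\big| \le \frac{1}{4} \Big(\max_j \sum_k \big|w^{(l)}_{jk}\big|\Big) \max_k \big|\delta^{(l+1)}_k\big|
$$

So whenever the absolute weight sums are at most 1, each layer passed on the way back shrinks the error signal by at least a factor of 4; five layers down it is at most $(1/4)^5 = 1/1024$ of the output error.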
