Original post: minfanphd
Task plan
Day 71: BP Neural Network Base Class (Data Reading and Basic Structure)
The following material comes from the book Neural Networks and Deep Learning (《神经网络与深度学习》) by Qiu Xipeng (邱锡鹏).
Each layer's neurons receive signals from the neurons of the previous layer and send their output signals to the next layer. Layer 0 is called the input layer, the last layer is called the output layer, and the other intermediate layers are called hidden layers.
$$z^{(l)} = W^{(l)}a^{(l-1)}+b^{(l)},$$
$$a^{(l)} = f_l(z^{(l)}).$$
Here $z^{(l)}$ denotes the net input of layer $l$, i.e., the value before the activation function is applied; $a^{(l)}$ is the output after the activation function; and $W, b$ denote the connection weights and biases of all layers in the network.
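To make the two formulas concrete, here is a minimal sketch of one forward step for a single layer with 2 inputs and 1 output node (the numbers are made up for illustration, not taken from the book):

// One forward step: z = W a + b, then a' = sigmoid(z), for a 1-node layer with 2 inputs.
public class ForwardStepDemo {
    public static void main(String[] args) {
        double[] a = {0.5, -1.0};    // a^(l-1): output of the previous layer
        double[] w = {0.2, 0.4};     // one row of W^(l)
        double b = 0.1;              // bias b^(l)
        double z = b;                // net input z^(l)
        for (int i = 0; i < a.length; i++) {
            z += w[i] * a[i];
        }
        double activated = 1 / (1 + Math.exp(-z)); // a^(l) = f_l(z^(l)) with sigmoid
        System.out.println("z = " + z + ", a = " + activated);
    }
}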
package MachineLearning.ann;
import weka.core.Instances;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
/**
 * @description: Abstract ANN class
 * @learner: Qing Zhang
 * @time: 07
 */
public abstract class GeneralAnn {
//The dataset
Instances dataset;
//Number of layers
int numLayers;
//Number of nodes in each layer, e.g., [3, 4, 6, 2] means the input layer has 3 nodes, the hidden layers have 4 and 6 nodes, and the output layer has 2 nodes (binary classification)
int[] layerNumNodes;
//Momentum coefficient
public double mobp;
//Learning rate
public double learningRate;
//Random number generator
Random random = new Random();
/**
 * @Description: The constructor
 * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
 * @return:
 */
public GeneralAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
try{
FileReader tempReader = new FileReader(paraFileName);
dataset = new Instances(tempReader);
dataset.setClassIndex(dataset.numAttributes()-1);
tempReader.close();
}catch (Exception ee){
System.out.println("Error occurred while trying to read \'" + paraFileName
+ "\' in GeneralAnn constructor.\r\n" + ee);
System.exit(0);
}
//Accept the parameters
layerNumNodes = paraLayerNumNodes;
numLayers = layerNumNodes.length;
learningRate = paraLearningRate;
layerNumNodes[0] = dataset.numAttributes() - 1;
layerNumNodes[numLayers - 1] = dataset.numClasses();
mobp = paraMobp;
}
/**
 * @Description: Forward prediction
 * @Param: [paraInput]
 * @return: double[]
 */
public abstract double[] forward(double[] paraInput);
/**
 * @Description: Back propagation
 * @Param: [paraTarget]
 * @return: void
 */
public abstract void backPropagation(double[] paraTarget);
/**
 * @Description: Train with the dataset
 * @Param: []
 * @return: void
 */
public void train(){
double[] tempInput = new double[dataset.numAttributes() - 1];
double[] tempTarget = new double[dataset.numClasses()];
for (int i = 0; i < dataset.numInstances(); i++) {
//Fill the input data
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
//Fill the class label (one-hot)
Arrays.fill(tempTarget, 0);
tempTarget[(int) dataset.instance(i).classValue()] = 1;
//Train with this instance
forward(tempInput);
backPropagation(tempTarget);
}
}
/**
 * @Description: Get the index of the maximum value in the array
 * @Param: [paraArray]
 * @return: int
 */
public static int argmax(double[] paraArray) {
int resultIndex = -1;
double tempMax = -1e10;
for (int i = 0; i < paraArray.length; i++) {
if (tempMax < paraArray[i]) {
tempMax = paraArray[i];
resultIndex = i;
}
}
return resultIndex;
}
/**
 * @Description: Test with the dataset
 * @Param: []
 * @return: double
 */
public double test() {
double[] tempInput = new double[dataset.numAttributes() - 1];
double tempNumCorrect = 0;
double[] tempPrediction;
int tempPredictedClass = -1;
for (int i = 0; i < dataset.numInstances(); i++) {
//Fill the input data
for (int j = 0; j < tempInput.length; j++) {
tempInput[j] = dataset.instance(i).value(j);
}
//Predict with this instance
tempPrediction = forward(tempInput);
System.out.println("prediction: " + Arrays.toString(tempPrediction));
tempPredictedClass = argmax(tempPrediction);
if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
tempNumCorrect++;
}
}
System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());
return tempNumCorrect / dataset.numInstances();
}
}
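As a quick illustration of the one-hot encoding inside train(): assuming dataset.numClasses() is 3 (as for the iris data used later) and the current instance's classValue() is 1, the target vector becomes {0.0, 1.0, 0.0}. A minimal standalone sketch:

import java.util.Arrays;

// One-hot encoding of a class label, as done in GeneralAnn.train().
public class OneHotDemo {
    public static void main(String[] args) {
        int tempNumClasses = 3;  // assumption: e.g., iris has 3 classes
        int tempClassValue = 1;  // assumption: label of the current instance
        double[] tempTarget = new double[tempNumClasses];
        Arrays.fill(tempTarget, 0);
        tempTarget[tempClassValue] = 1;
        System.out.println(Arrays.toString(tempTarget)); // [0.0, 1.0, 0.0]
    }
}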
Day 72: BP Neural Network with a Fixed Activation Function (1. Understanding the Network Structure)
- layerNumNodes describes the basic network structure. For example, [3, 4, 6, 2] means:
a) There are 3 input ports, i.e., the data has 3 condition attributes. If this does not match the actual data, the code corrects it automatically; see the assignment to layerNumNodes[0] in the GeneralAnn constructor.
b) There are 2 output ports, i.e., the data has 2 decision classes. If this does not match the actual data, the code corrects it automatically; see the assignment to layerNumNodes[numLayers - 1] in the GeneralAnn constructor. For classification, the predicted class corresponds to the port with the largest output value.
c) There are two intermediate layers (i.e., hidden layers), with 4 and 6 nodes respectively.
The correction is mainly to stay consistent with the dataset: these parameters are set by hand and may be wrong, so it is more rigorous to correct them according to the actual dataset.
- layerNodeValues holds the value of each network node. In the example above the network has 4 layers of nodes, i.e., layerNodeValues.length is 4. The total number of nodes is $3 + 4 + 6 + 2 = 15$, i.e., layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6, layerNodeValues[3].length = 2. Java supports such ragged matrices (different rows with different numbers of columns), because a two-dimensional matrix is treated as a one-dimensional vector of one-dimensional vectors.
- layerNodeErrors holds the error at each network node. Its size is identical to that of layerNodeValues.
- edgeWeights holds the weight of each edge. Since the edges between two layers form a many-to-many relation (a two-dimensional array), the edges of multiple layers form a three-dimensional array. For example, layer 0 in the example above has $(3 + 1) \times 4 = 16$ edges, where the $+1$ accounts for the offset (bias). There are $4 - 1 = 3$ layers of edges in total, i.e., one fewer than the number of node layers. This is a very easy place to make mistakes when writing the program.
- edgeWeightsDelta has the same size as edgeWeights and assists in adjusting it. A sketch of the resulting ragged array shapes follows this list.
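A minimal sketch (assuming the [3, 4, 6, 2] structure from the example) that allocates the ragged edge arrays and prints their shapes:

// Allocate ragged edge arrays for layer structure {3, 4, 6, 2} and print their shapes.
public class RaggedShapeDemo {
    public static void main(String[] args) {
        int[] layerNumNodes = {3, 4, 6, 2};
        int numLayers = layerNumNodes.length;
        double[][][] edgeWeights = new double[numLayers - 1][][]; // one fewer than node layers
        for (int l = 0; l < numLayers - 1; l++) {
            // +1 row for the offset (bias), as in SimpleAnn below.
            edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
            System.out.println("Edge layer " + l + ": " + edgeWeights[l].length + " x "
                    + edgeWeights[l][0].length + " = "
                    + edgeWeights[l].length * edgeWeights[l][0].length + " edges");
        }
    }
}

The output is 16, 30, and 14 edges for the three edge layers, matching the count given above for layer 0.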
To follow the weight-adjustment code you also need the optimization method; momentum is used here. For the underlying idea, see the post
深度学习优化函数详解(4)-- momentum 动量法. The update rule it implements is sketched below.
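A minimal sketch of the momentum update that the adjustment code below implements (the names and the gradient values are illustrative, not from the original):

// Momentum update rule: delta <- mobp * delta + learningRate * gradient; weight += delta.
public class MomentumDemo {
    public static void main(String[] args) {
        double mobp = 0.6, learningRate = 0.01; // the defaults used in main() below
        double weight = 0.5, delta = 0;
        double[] gradients = {0.3, 0.3, 0.3};   // illustrative gradient sequence
        for (double g : gradients) {
            delta = mobp * delta + learningRate * g; // momentum accumulates past steps
            weight += delta;
            System.out.println("delta = " + delta + ", weight = " + weight);
        }
    }
}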
Here is the core code:
package MachineLearning.ann;
import weka.core.Instances;
import java.io.FileReader;
/**
 * @description: ANN with a fixed (sigmoid) activation function
 * @learner: Qing Zhang
 * @time: 07
 */
public class SimpleAnn extends GeneralAnn {
//The value of each node computed in forward propagation. The first dimension is the layer, the second the node
public double[][] layerNodeValues;
//The error of each node computed in back propagation. The first dimension is the layer, the second the node
public double[][] layerNodeErrors;
//The edge weights. The first dimension is the layer, the second the node index in that layer, the third the node index in the next layer
public double[][][] edgeWeights;
//The change of each edge weight. Its size is the same as edgeWeights
public double[][][] edgeWeightsDelta;
/**
 * @Description: The constructor
 * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
 * @return:
 */
public SimpleAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);
//Initialize across layers
layerNodeValues = new double[numLayers][];
layerNodeErrors = new double[numLayers][];
edgeWeights = new double[numLayers - 1][][];
edgeWeightsDelta = new double[numLayers - 1][][];
//Initialize within each layer
for (int l = 0; l < numLayers; l++) {
layerNodeValues[l] = new double[layerNumNodes[l]];
layerNodeErrors[l] = new double[layerNumNodes[l]];
//One fewer layer of edges is initialized below, since each edge spans two layers
if (l + 1 == numLayers) {
break;
}
//In layerNumNodes[l] + 1, the last position is reserved for the offset (bias).
edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
for (int i = 0; i < layerNumNodes[l] + 1; i++) {
for (int j = 0; j < layerNumNodes[l + 1]; j++) {
//Initialize the weights
edgeWeights[l][i][j] = random.nextDouble();
}
}
}
}
@Override
public double[] forward(double[] paraInput) {
//Initialize the input layer
for (int i = 0; i < layerNodeValues[0].length; i++) {
layerNodeValues[0][i] = paraInput[i];
}
//Compute the node values of each layer
double z;
for (int l = 1; l < numLayers; l++) {
for (int j = 0; j < layerNodeValues[l].length; j++) {
//Initialize with the bias, whose implicit input is +1
//The bias is added first here
z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
//Add the weighted sum over all incoming edges to this node
for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
}
//Sigmoid activation function
//For other activation functions, this line should change.
layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
}
}
return layerNodeValues[numLayers - 1];
}
@Override
public void backPropagation(double[] paraTarget) {
//Initialize the errors of the output layer
int l = numLayers - 1;
for (int j = 0; j < layerNodeErrors[l].length; j++) {
layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * (paraTarget[j] - layerNodeValues[l][j]);
}
//Back propagate until l == 0
while (l > 0) {
l--;
//Each node of layer l
for (int j = 0; j < layerNumNodes[l]; j++) {
double z = 0.0;
//Each node of the next layer
for (int i = 0; i < layerNumNodes[l + 1]; i++) {
if (l > 0) {
z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
}
//Adjust the weights
edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i] + learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
if (j == layerNumNodes[l] - 1) {
//Adjust the weight of the offset (bias) part
edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
+ learningRate * layerNodeErrors[l + 1][i];
edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
}
}
//Record the error according to the derivative of sigmoid.
//For other activation functions, this line should change.
layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
}
}
}
public static void main(String[] args) {
int[] tempLayerNodes = {4, 8, 8, 3};
SimpleAnn tempNetwork = new SimpleAnn("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01,
0.6);
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
}
double tempAccuracy = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuracy);
}
}
Day 73: BP Neural Network with a Fixed Activation Function (2. Understanding Training and Testing)
- Forward is the process of predicting one data instance with the current network.
- BackPropagation is the process of adjusting the network weights according to the error.
- Training needs both the forward and the backward pass; testing only needs the forward pass.
- Only the sigmoid activation function is implemented here; the derivative used in back propagation corresponds to the activation function used in forward propagation. To switch activation functions, both places must change together. (This deserves careful study: the backward pass also involves the optimization method, so the code must be adjusted according to both the optimization method and the activation function. The sigmoid case is worked out below.)
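The correspondence is easy to verify for sigmoid. A short derivation of the derivative used in the code above:
$$\sigma(x) = \frac{1}{1+e^{-x}}, \qquad \sigma'(x) = \frac{e^{-x}}{(1+e^{-x})^2} = \frac{1}{1+e^{-x}} \cdot \frac{e^{-x}}{1+e^{-x}} = \sigma(x)\left(1-\sigma(x)\right).$$
This is why backPropagation can write the derivative as layerNodeValues[l][j] * (1 - layerNodeValues[l][j]), using only the already-activated value.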
Day 74: Generic BP Neural Network (1. Centralized Management of Activation Functions)
- Activation and differentiation come in pairs: the former is used in forward, the latter in back-propagation.
- There are many activation functions, and their design follows certain criteria, such as being piecewise differentiable.
- Look up references and implement some of the missing activation functions.
- Test further.
Sigmoid:
$$\sigma(x) = \frac{1}{1+e^{-x}}$$
Tanh:
$$\sigma(x) = \frac{2}{1+e^{-2x}}-1$$
Arctan:
$$\sigma(x) = \arctan(x)$$
Elu:
$$\sigma(x) = \begin{cases} x, & x\geq 0\\ \alpha(e^x-1), & x<0 \end{cases}$$
Identity:
$$\sigma(x) = x$$
Soft Sign:
$$\sigma(x) = \begin{cases} \frac{x}{1+x}, & x\geq 0\\ \frac{x}{1-x}, & x<0 \end{cases}$$
Soft Plus:
$$\sigma(x) = \log(1+e^x)$$
Relu:
$$\sigma(x) = \begin{cases} x, & x\geq 0\\ 0, & x<0 \end{cases}$$
Leaky Relu:
$$\sigma(x) = \begin{cases} x, & x\geq 0\\ \alpha x, & x<0 \end{cases}$$
Plotting source code:
from matplotlib import pyplot as plt
import numpy as np
import math

def sigmoid_function(x):
    fz = []
    for num in x:
        fz.append(1 / (1 + math.exp(-num)))
    return fz

def sigmoid_test():
    x = np.arange(-10, 10, 0.01)
    fz = sigmoid_function(x)
    show_graph('Sigmoid Function', 'x', 'σ(x)', x, fz)

def tanh_function(x):
    fz = []
    for num in x:
        fz.append(2 / (1 + math.exp(-2 * num)) - 1)
    return fz

def tanh_test():
    x = np.arange(-10, 10, 0.01)
    fz = tanh_function(x)
    show_graph('Tanh Function', 'x', 'σ(x)', x, fz)

def arctan_function(x):
    fz = []
    for num in x:
        fz.append(math.atan(num))
    return fz

def arctan_test():
    x = np.arange(-50, 50, 0.01)
    fz = arctan_function(x)
    show_graph('Arctan Function', 'x', 'σ(x)', x, fz)

def elu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * (math.exp(num) - 1))
    return fz

def elu_test():
    x = np.arange(-50, 50, 0.01)
    fz = elu_function(x, 0.5)
    show_graph('Elu Function', 'x', 'σ(x)', x, fz)

def identity_function(x):
    fz = []
    for num in x:
        fz.append(num)
    return fz

def identity_test():
    x = np.arange(-10, 10, 0.01)
    fz = identity_function(x)
    show_graph('Identity Function', 'x', 'σ(x)', x, fz)

def leakyRelu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * num)
    return fz

def leakyRelu_test():
    x = np.arange(-10, 10, 0.01)
    alpha = 0.5
    fz = leakyRelu_function(x, alpha)
    show_graph('Leaky Relu Function', 'x', 'σ(x)', x, fz)

def softSign_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num / (1 + num))
        else:
            fz.append(num / (1 - num))
    return fz

def softSign_test():
    x = np.arange(-10, 10, 0.01)
    fz = softSign_function(x)
    show_graph('Soft Sign Function', 'x', 'σ(x)', x, fz)

def softPlus_function(x):
    fz = []
    for num in x:
        fz.append(math.log(1 + math.exp(num)))
    return fz

def softPlus_test():
    x = np.arange(-10, 10, 0.01)
    fz = softPlus_function(x)
    show_graph('Soft Plus Function', 'x', 'σ(x)', x, fz)

def relu_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(0)
    return fz

def relu_test():
    x = np.arange(-10, 10, 0.01)
    fz = relu_function(x)
    show_graph('Relu Function', 'x', 'σ(x)', x, fz)

def show_graph(title, xlabel, ylabel, x, fz):
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.plot(x, fz)
    plt.show()

if __name__ == '__main__':
    sigmoid_test()
    tanh_test()
    arctan_test()
    elu_test()
    identity_test()
    softSign_test()
    softPlus_test()
    relu_test()
    leakyRelu_test()
package MachineLearning.ann;
/**
 * @description: Activation functions
 * @learner: Qing Zhang
 * @time: 07
 */
public class Activator {
// Arc tan.
public final char ARC_TAN = 'a';
// Elu.
public final char ELU = 'e';
// Gelu.
public final char GELU = 'g';
// Hard logistic.
public final char HARD_LOGISTIC = 'h';
// Identity.
public final char IDENTITY = 'i';
// Leaky relu, also known as parametric relu.
public final char LEAKY_RELU = 'l';
// Relu.
public final char RELU = 'r';
// Soft sign.
public final char SOFT_SIGN = 'o';
// Sigmoid.
public final char SIGMOID = 's';
// Tanh.
public final char TANH = 't';
// Soft plus.
public final char SOFT_PLUS = 'u';
// Swish.
public final char SWISH = 'w';
// The activator.
private char activator;
// Alpha for elu.
double alpha;
// Beta for leaky relu.
double beta;
// Gamma for leaky relu.
double gamma;
/**
 * @Description: The constructor
 * @Param: [paraActivator]
 * @return:
 */
public Activator(char paraActivator) {
activator = paraActivator;
}
/**
 * @Description: Setter
 * @Param: [paraActivator]
 * @return: void
 */
public void setActivator(char paraActivator) {
activator = paraActivator;
}
/**
 * @Description: Getter
 * @Param: []
 * @return: char
 */
public char getActivator() {
return activator;
}
/**
 * @Description: Set alpha
 * @Param: [paraAlpha]
 * @return: void
 */
void setAlpha(double paraAlpha) {
alpha = paraAlpha;
}// Of setAlpha
/**
 * @Description: Set beta
 * @Param: [paraBeta]
 * @return: void
 */
void setBeta(double paraBeta) {
beta = paraBeta;
}
/**
 * @Description: Set gamma
 * @Param: [paraGamma]
 * @return: void
 */
void setGamma(double paraGamma) {
gamma = paraGamma;
}
/**
 * @Description: Activate according to the configured activation function
 * @Param: [paraValue]
 * @return: double
 */
public double activate(double paraValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = Math.atan(paraValue);
break;
case ELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1);
}
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = paraValue;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = alpha * paraValue;
}
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = paraValue / (1 + paraValue);
} else {
resultValue = paraValue / (1 - paraValue);
}
break;
case SOFT_PLUS:
resultValue = Math.log(1 + Math.exp(paraValue));
break;
case RELU:
if (paraValue >= 0) {
resultValue = paraValue;
} else {
resultValue = 0;
}
break;
case SIGMOID:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case TANH:
resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}
return resultValue;
}
/**
 * @Description: Differentiate according to the activation function. Some derivatives use x, others use f(x)
 * @Param: [paraValue:x, paraActivatedValue:f(x)]
 * @return: double
 */
public double derive(double paraValue, double paraActivatedValue) {
double resultValue = 0;
switch (activator) {
case ARC_TAN:
resultValue = 1 / (paraValue * paraValue + 1);
break;
case ELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
}
break;
// case GELU:
// resultValue = ?;
// break;
// case HARD_LOGISTIC:
// resultValue = ?;
// break;
case IDENTITY:
resultValue = 1;
break;
case LEAKY_RELU:
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = alpha;
}
break;
case SOFT_SIGN:
if (paraValue >= 0) {
resultValue = 1 / (1 + paraValue) / (1 + paraValue);
} else {
resultValue = 1 / (1 - paraValue) / (1 - paraValue);
}
break;
case SOFT_PLUS:
resultValue = 1 / (1 + Math.exp(-paraValue));
break;
case RELU: // Updated
if (paraValue >= 0) {
resultValue = 1;
} else {
resultValue = 0;
}
break;
case SIGMOID: // Updated
resultValue = paraActivatedValue * (1 - paraActivatedValue);
break;
case TANH: // Updated
resultValue = 1 - paraActivatedValue * paraActivatedValue;
break;
// case SWISH:
// resultValue = ?;
// break;
default:
System.out.println("Unsupported activator: " + activator);
System.exit(0);
}
return resultValue;
}
public String toString() {
String resultString = "Activator with function '" + activator + "'";
resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;
return resultString;
}
public static void main(String[] args) {
Activator tempActivator = new Activator('s');
double tempValue = 0.6;
double tempNewValue;
tempNewValue = tempActivator.activate(tempValue);
System.out.println("After activation: " + tempNewValue);
tempNewValue = tempActivator.derive(tempValue, tempNewValue);
System.out.println("After derive: " + tempNewValue);
}
}
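Day 74's to-do list leaves GELU and SWISH unimplemented (the "resultValue = ?" branches above). A standalone sketch of Swish, under my own assumption that the existing beta field serves as its parameter (the original leaves this choice open):

// A sketch of Swish matching the activate()/derive() contract of Activator.
public class SwishSketch {
    static double beta = 1.0; // assumption: beta = 1 gives the common SiLU variant

    static double activate(double x) {
        return x / (1 + Math.exp(-beta * x)); // swish(x) = x * sigmoid(beta * x)
    }

    static double derive(double x) {
        double s = 1 / (1 + Math.exp(-beta * x)); // s = sigmoid(beta * x)
        return s + beta * x * s * (1 - s);        // d/dx [x * s] = s + beta * x * s * (1 - s)
    }

    public static void main(String[] args) {
        System.out.println(activate(0.6) + ", " + derive(0.6));
    }
}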
Day 75: Generic BP Neural Network (2. Single-Layer Implementation)
- Implements a single ANN layer only.
- It can have its own activation function.
- The forward pass computes the output; the backward pass computes the error and adjusts the weights.
Here a single ANN layer is coded and tested; the Activator class built earlier can be used to switch the activation function.
package MachineLearning.ann;
import java.util.Arrays;
import java.util.Random;
/**
 * @description: One ANN layer
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnLayer {
//Number of inputs
int numInput;
//Number of outputs
int numOutput;
//Learning rate
double learningRate;
//Momentum coefficient
double mobp;
//Weight matrices
double[][] weights, deltaWeights;
double[] offset, deltaOffset, errors;
//The input
double[] input;
//The output before activation
double[] output;
//The output after activation
double[] activatedOutput;
//The activator
Activator activator;
//Random number generator
Random random = new Random();
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator, double paraLearningRate, double paraMobp) {
numInput = paraNumInput;
numOutput = paraNumOutput;
learningRate = paraLearningRate;
mobp = paraMobp;
weights = new double[numInput + 1][numOutput];
deltaWeights = new double[numInput + 1][numOutput];
for (int i = 0; i < numInput + 1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j] = random.nextDouble();
}
}
offset = new double[numOutput];
deltaOffset = new double[numOutput];
errors = new double[numInput];
input = new double[numInput];
output = new double[numOutput];
activatedOutput = new double[numOutput];
activator = new Activator(paraActivator);
}
/**
 * @Description: Forward prediction
 * @Param: [paraInput]
 * @return: double[]
 */
public double[] forward(double[] paraInput) {
//Copy the data
for (int i = 0; i < numInput; i++) {
input[i] = paraInput[i];
}
//Compute the weighted sum to obtain each output
for (int i = 0; i < numOutput; i++) {
output[i] = weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i] += input[j] * weights[j][i];
}
activatedOutput[i] = activator.activate(output[i]);
}
return activatedOutput;
}
/**
 * @Description: Back propagation and adjust the weights
 * @Param: [paraErrors]
 * @return: double[]
 */
public double[] backPropagation(double[] paraErrors) {
//Scale the incoming errors by the derivative of the activation function
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
}
//Compute the errors of the current layer (to pass to the previous layer)
for (int i = 0; i < numInput; i++) {
errors[i] = 0;
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j] = mobp * deltaWeights[i][j] + learningRate * paraErrors[j] * input[i];
weights[i][j] += deltaWeights[i][j];
if (i == numInput - 1) {
//Adjust the bias. Forward reads the bias from the last row of weights,
//so update that row (otherwise the separate offset array is never used in forward)
deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
weights[numInput][j] += deltaWeights[numInput][j];
}
}
}
return errors;
}
/**
 * @Description: Get the errors of the last layer
 * @Param: [paraTarget]
 * @return: double[]
 */
public double[] getLastLayerErrors(double[] paraTarget) {
double[] resultErrors = new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
}
return resultErrors;
}
@Override
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);
return resultString;
}
/**
 * @Description: Unit test
 * @Param: []
 * @return: void
 */
public static void unitTest() {
AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = {1, 4};
System.out.println(tempLayer);
double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));
double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}
public static void main(String[] args) {
unitTest();
}
}
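A minimal sketch of how two AnnLayer objects chain together (my own illustration of the mechanism that FullAnn below automates; it assumes the MachineLearning.ann package):

import java.util.Arrays;

// Chain two AnnLayer objects by hand: the forward output feeds the next layer,
// and the backward errors flow in the reverse direction.
public class TwoLayerDemo {
    public static void main(String[] args) {
        AnnLayer tempFirst = new AnnLayer(2, 3, 's', 0.01, 0.6);
        AnnLayer tempSecond = new AnnLayer(3, 1, 's', 0.01, 0.6);
        double[] tempInput = {1.0, 0.5};
        double[] tempTarget = {1.0};
        double[] tempHidden = tempFirst.forward(tempInput);
        double[] tempOutput = tempSecond.forward(tempHidden);
        System.out.println("Output: " + Arrays.toString(tempOutput));
        double[] tempErrors = tempSecond.getLastLayerErrors(tempTarget);
        tempErrors = tempSecond.backPropagation(tempErrors);
        tempFirst.backPropagation(tempErrors);
    }
}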
Day 76: Generic BP Neural Network (3. Comprehensive Testing)
- Try other activation functions yourself.
package MachineLearning.ann;
/**
 * @description: The complete neural network
 * @learner: Qing Zhang
 * @time: 07
 */
public class FullAnn extends GeneralAnn {
AnnLayer[] layers;
public FullAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp, String paraActivators) {
super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);
//Initialize the layers
layers = new AnnLayer[numLayers - 1];
for (int i = 0; i < layers.length; i++) {
layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate, paraMobp);
}
}
@Override
public double[] forward(double[] paraInput) {
double[] resultArray = paraInput;
for (int i = 0; i < numLayers - 1; i++) {
resultArray = layers[i].forward(resultArray);
}
return resultArray;
}
@Override
public void backPropagation(double[] paraTarget) {
double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
for (int i = numLayers - 2; i >= 0; i--) {
tempErrors = layers[i].backPropagation(tempErrors);
}
}
@Override
public String toString() {
String resultString = "I am a full ANN with " + numLayers + " layers";
return resultString;
}
public static void main(String[] args) {
int[] tempLayerNodes = {4, 8, 8, 3};
FullAnn tempNetwork = new FullAnn("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01, 0.6, "sss");
for (int round = 0; round < 5000; round++) {
tempNetwork.train();
}
double tempAccuracy = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuracy);
System.out.println("FullAnn ends.");
}
}
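Note that the activators string supplies one character per layer of edges, so its length must be numLayers - 1 ("sss" for the 4-layer structure above). A hypothetical mixed configuration using the already-implemented relu, sigmoid, and tanh (same file path as in main):

// 4 node layers -> 3 edge layers -> 3 activator characters: relu, sigmoid, tanh.
FullAnn tempMixed = new FullAnn("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff",
        new int[] {4, 8, 8, 3}, 0.01, 0.6, "rst");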
(Result screenshots for the Sigmoid, SOFT_SIGN, SOFT_PLUS, RELU, LEAKY_RELU, and ELU activators are omitted here.)
Day 77: GUI (1. Dialog-Related Controls)
- ApplicationShutdown.java is used only to exit the graphical user interface (GUI).
- Only one static instance is ever generated. The constructor is private, so no instance can be created with new outside the class. This is an interesting little trick.
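The trick is the eagerly initialized singleton pattern; a generic sketch (the names are mine, not from the project):

// Eager singleton: exactly one static instance, and a private constructor.
public class Singleton {
    public static final Singleton INSTANCE = new Singleton();

    private Singleton() {
        // Cannot be invoked with `new` outside this class.
    }

    public static void main(String[] args) {
        System.out.println(Singleton.INSTANCE == Singleton.INSTANCE); // true: the same object
    }
}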
package MachineLearning.gui;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;
/**
 * @description: Close the application via a window event or a button event
 * @learner: Qing Zhang
 * @time: 07
 */
public class ApplicationShutdown implements WindowListener, ActionListener {
//Only one instance may exist
public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();
//The constructor is private: only one instance may exist, and the static one is already declared.
private ApplicationShutdown() {
}
//Shut down the system
public void windowClosing(WindowEvent comeInWindowEvent) {
System.exit(0);
}// Of windowClosing.
public void windowActivated(WindowEvent comeInWindowEvent) {
}
public void windowClosed(WindowEvent comeInWindowEvent) {
}
public void windowDeactivated(WindowEvent comeInWindowEvent) {
}
public void windowDeiconified(WindowEvent comeInWindowEvent) {
}
public void windowIconified(WindowEvent comeInWindowEvent) {
}
public void windowOpened(WindowEvent comeInWindowEvent) {
}
public void actionPerformed(ActionEvent ee) {
System.exit(0);
}
}
DialogCloser.java closes the current dialog rather than the whole GUI.
package MachineLearning.gui;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
/**
 * @description: Close the current dialog
 * @learner: Qing Zhang
 * @time: 07
 */
public class DialogCloser extends WindowAdapter implements ActionListener {
//The dialog currently open
private Dialog currentDialog;
public DialogCloser() {
super();
}
public DialogCloser(Dialog paraDialog) {
currentDialog = paraDialog;
}// Of the second constructor
/**
 * @Description: Close the dialog
 * when the close button in the title bar is clicked
 * @Param: [paraWindowEvent]
 * @return: void
 */
public void windowClosing(WindowEvent paraWindowEvent) {
paraWindowEvent.getWindow().dispose();
}
/**
 * @Description: Close the dialog
 * when the "OK" or "Cancel" button is pushed
 * @Param: [paraEvent]
 * @return: void
 */
public void actionPerformed(ActionEvent paraEvent) {
currentDialog.dispose();
}
}
ErrorDialog.java displays error messages. With a GUI we no longer need to use System.out.println.
package MachineLearning.gui;
import java.awt.*;
/**
 * @description: Error dialog
 * @learner: Qing Zhang
 * @time: 07
 */
public class ErrorDialog extends Dialog {
//Serial uid. Not necessarily useful
private static final long serialVersionUID = 124535235L;
//The only error dialog
public static ErrorDialog errorDialog = new ErrorDialog();
//The text area for displaying the message
private TextArea messageTextArea;
/**
 * @Description: The error dialog.
 * Like the other dialogs, only one instance of it exists, which saves memory:
 * when many errors occur, one error dialog suffices
 * @Param: []
 * @return:
 */
private ErrorDialog() {
//A modal dialog
super(GUICommon.mainFrame, "Error", true);
//Initialize the contents of this dialog
messageTextArea = new TextArea();
Button okButton = new Button("OK");
okButton.setSize(20, 10);
okButton.addActionListener(new DialogCloser(this));
Panel okPanel = new Panel();
okPanel.setLayout(new FlowLayout());
okPanel.add(okButton);
//Add the text area and the button
setLayout(new BorderLayout());
add(BorderLayout.CENTER, messageTextArea);
add(BorderLayout.SOUTH, okPanel);
setLocation(200, 200);
setSize(500, 200);
addWindowListener(new DialogCloser());
setVisible(false);
}
/**
 * @Description: Set the message and show the dialog
 * @Param: [paramMessage]
 * @return: void
 */
public void setMessageAndShow(String paramMessage) {
messageTextArea.setText(paramMessage);
setVisible(true);
}
}
GUICommon.java stores some shared variables.
package MachineLearning.gui;
import javax.swing.*;
import java.awt.*;
/**
 * @description: Common variables
 * @learner: Qing Zhang
 * @time: 07
 */
public class GUICommon extends Object {
//Only one main frame
public static Frame mainFrame = null;
//Only one main pane
public static JTabbedPane mainPane = null;
//Default project number
public static int currentProjectNumber = 0;
//Default font
public static final Font MY_FONT = new Font("Times New Romans", Font.PLAIN, 12);
//Default color
public static final Color MY_COLOR = Color.lightGray;
/**
 * @Description: Set the main frame. This is done only once, at the very beginning
 * @Param: [paraFrame]
 * @return: void
 */
public static void setFrame(Frame paraFrame) throws Exception {
if (mainFrame == null) {
mainFrame = paraFrame;
} else {
throw new Exception("The main frame can be set only ONCE!");
}
}
/**
 * @Description: Set the main pane. This is done only once, at the very beginning
 * @Param: [paramPane]
 * @return: void
 */
public static void setPane(JTabbedPane paramPane) throws Exception {
if (mainPane == null) {
mainPane = paramPane;
} else {
throw new Exception("The main panel can be set only ONCE!");
}
}
}
HelpDialog.java displays help information: clicking the Help button on the main interface shows an explanation of the relevant parameters. Its purpose is to improve the software's usability and maintainability.
package MachineLearning.gui;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.io.RandomAccessFile;
/**
 * @description: Help dialog
 * @learner: Qing Zhang
 * @time: 07
 */
public class HelpDialog extends Dialog implements ActionListener {
/**
* Serial uid. Not quite useful.
*/
private static final long serialVersionUID = 3869415040299264995L;
/**
 * @Description: Display the help dialog
 * @Param: [paraTitle, paraFilename]
 * @return:
 */
public HelpDialog(String paraTitle, String paraFilename) {
super(GUICommon.mainFrame, paraTitle, true);
setBackground(GUICommon.MY_COLOR);
TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
displayArea.setEditable(false);
String textToDisplay = "";
try {
RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
String tempLine = helpFile.readLine();
while (tempLine != null) {
textToDisplay = textToDisplay + tempLine + "\n";
tempLine = helpFile.readLine();
}
helpFile.close();
} catch (IOException ee) {
dispose();
ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
}
// Use this method if Chinese text needs to be displayed:
// textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
displayArea.setText(textToDisplay);
displayArea.setFont(new Font("Times New Romans", Font.PLAIN, 14));
Button okButton = new Button("OK");
okButton.setSize(20, 10);
okButton.addActionListener(new DialogCloser(this));
Panel okPanel = new Panel();
okPanel.setLayout(new FlowLayout());
okPanel.add(okButton);
// The OK button
setLayout(new BorderLayout());
add(BorderLayout.CENTER, displayArea);
add(BorderLayout.SOUTH, okPanel);
setLocation(120, 70);
setSize(500, 400);
addWindowListener(new DialogCloser());
setVisible(false);
}
/**
 * @Description: Simply make the dialog visible
 * @Param: [ee]
 * @return: void
 */
public void actionPerformed(ActionEvent ee) {
setVisible(true);
}
}
Day 78: GUI (2. Data Input Controls)
DoubleField.java accepts a double value and reports an error if the text cannot be parsed as a double. This nips the user's low-level mistakes in the bud.
package MachineLearning.gui;
import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
/**
 * @description: Accepts a double value
 * @learner: Qing Zhang
 * @time: 07
 */
public class DoubleField extends TextField implements FocusListener {
//Serial uid. Not necessarily useful
private static final long serialVersionUID = 363634723L;
//The value
protected double doubleValue;
//Give a default value
public DoubleField() {
this("5.13", 10);
}// Of the first constructor
//Specify the content only
public DoubleField(String paraString) {
this(paraString, 10);
}// Of the second constructor
//Specify the width only
public DoubleField(int paraWidth) {
this("5.13", paraWidth);
}// Of the third constructor
/**
 * @Description: Specify the content and the width
 * @Param: [paraString, paraWidth]
 * @return:
 */
public DoubleField(String paraString, int paraWidth) {
super(paraString, paraWidth);
addFocusListener(this);
}
/**
 * @Description: The focus gained event (not handled)
 * @Param: [paraEvent]
 * @return: void
 */
public void focusGained(FocusEvent paraEvent) {
}
/**
 * @Description: Handle the focus lost event
 * @Param: [paraEvent]
 * @return: void
 */
public void focusLost(FocusEvent paraEvent) {
try {
doubleValue = Double.parseDouble(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog
.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
requestFocus();
}
}
/**
 * @Description: Get the value
 * @Param: []
 * @return: double
 */
public double getValue() {
try {
doubleValue = Double.parseDouble(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog
.setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
requestFocus();
}
return doubleValue;
}
}
IntegerField.java works in the same way.
package MachineLearning.gui;
import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
/**
 * @description: Accepts an int value
 * @learner: Qing Zhang
 * @time: 07
 */
public class IntegerField extends TextField implements FocusListener {
//Serial uid. Not necessarily useful
private static final long serialVersionUID = -2462338973265150779L;
//The default constructor
public IntegerField() {
this("513");
}// Of constructor
/**
 * @Description: Specify the content and the width
 * @Param: [paraString, paraWidth]
 * @return:
 */
public IntegerField(String paraString, int paraWidth) {
super(paraString, paraWidth);
addFocusListener(this);
}
//Specify the content only
public IntegerField(String paraString) {
super(paraString);
addFocusListener(this);
}
//Specify the width only
public IntegerField(int paraWidth) {
super(paraWidth);
setText("513");
addFocusListener(this);
}
/**
 * @Description: The focus gained event (not handled)
 * @Param: [paraEvent]
 * @return: void
 */
public void focusGained(FocusEvent paraEvent) {
}
/**
 * @Description: Handle the focus lost event
 * @Param: [paraEvent]
 * @return: void
 */
public void focusLost(FocusEvent paraEvent) {
try {
Integer.parseInt(getText());
// System.out.println(tempInt);
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
+ "\" Not an integer. Please check.");
requestFocus();
}
}
/**
 * @Description: Get the value
 * @Param: []
 * @return: int
 */
public int getValue() {
int tempInt = 0;
try {
tempInt = Integer.parseInt(getText());
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
+ "\" Not an int. Please check.");
requestFocus();
}
return tempInt;
}
}
FilenameField.java relies on the FileDialog provided by the system.
package MachineLearning.gui;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
import java.io.File;
/**
 * @description: Accepts a file name
 * @learner: Qing Zhang
 * @time: 07
 */
public class FilenameField extends TextField implements ActionListener, FocusListener {
//Serial uid. Not necessarily useful
private static final long serialVersionUID = 4572287941606065298L;
/**
 * @Description: Initialize
 * @Param: []
 * @return:
 */
public FilenameField() {
super();
setText("");
addFocusListener(this);
}
/**
 * @Description: Initialize
 * @Param: [paraWidth]
 * @return:
 */
public FilenameField(int paraWidth) {
super(paraWidth);
setText("");
addFocusListener(this);
}
/**
 * @Description: Initialize
 * @Param: [paraWidth, paraText]
 * @return:
 */
public FilenameField(int paraWidth, String paraText) {
super(paraWidth);
setText(paraText);
addFocusListener(this);
}
/**
 * @Description: Initialize
 * @Param: [paraText, paraWidth]
 * @return:
 */
public FilenameField(String paraText, int paraWidth) {
super(paraWidth);
setText(paraText);
addFocusListener(this);
}
/**
 * @Description: Avoid null or the empty string
 * @Param: [paraText]
 * @return: void
 */
public void setText(String paraText) {
if (paraText.trim().equals("")) {
super.setText("unspecified");
} else {
super.setText(paraText.replace('\\', '/'));
}
}
/**
 * @Description: Handle the action event (browse for a file)
 * @Param: [paraEvent]
 * @return: void
 */
public void actionPerformed(ActionEvent paraEvent) {
FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
"Select a file");
tempDialog.setVisible(true);
if (tempDialog.getDirectory() == null) {
setText("");
return;
}
String directoryName = tempDialog.getDirectory();
String tempFilename = directoryName + tempDialog.getFile();
//System.out.println("tempFilename = " + tempFilename);
setText(tempFilename);
}
/**
 * @Description: The focus gained event (not handled)
 * @Param: [paraEvent]
 * @return: void
 */
public void focusGained(FocusEvent paraEvent) {
}
/**
 * @Description: Handle the focus lost event (check that the file exists)
 * @Param: [paraEvent]
 * @return: void
 */
public void focusLost(FocusEvent paraEvent) {
// System.out.println("Focus lost exists.");
String tempString = getText();
if ((tempString.equals("unspecified"))
|| (tempString.equals("")))
return;
File tempFile = new File(tempString);
if (!tempFile.exists()) {
ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
+ "\" not exists. Please check.");
requestFocus();
setText("");
}
}
}
Day 79: GUI (3. Overall Layout)
- GridLayout and BorderLayout are used to organize the controls.
- Pressing OK executes actionPerformed; similar code already appeared in the previous two days.
package MachineLearning.gui;
import MachineLearning.ann.FullAnn;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Date;
/**
 * @description: The main GUI entry of the ANN
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnMain implements ActionListener {
//Select the arff file
private FilenameField arffFilenameField;
//Set alpha
private DoubleField alphaField;
//Set beta
private DoubleField betaField;
//Set gamma
private DoubleField gammaField;
//Nodes of each layer, e.g., "4, 8, 8, 3".
private TextField layerNodesField;
//Selection of activation functions, e.g., "ssa".
private TextField activatorField;
//Number of training rounds
private IntegerField roundsField;
//Learning rate
private DoubleField learningRateField;
//mobp
private DoubleField mobpField;
//The message area
private TextArea messageTextArea;
/**
 * @Description: The only constructor
 * @Param: []
 * @return:
 */
public AnnMain() {
//A simple frame to contain the dialog controls
Frame mainFrame = new Frame();
mainFrame.setTitle("ANN. minfanphd@163.com");
//Top: select the arff file
arffFilenameField = new FilenameField(30);
arffFilenameField.setText("d:/data/iris.arff");
Button browseButton = new Button(" Browse ");
browseButton.addActionListener(arffFilenameField);
Panel sourceFilePanel = new Panel();
sourceFilePanel.add(new Label("The .arff file:"));
sourceFilePanel.add(arffFilenameField);
sourceFilePanel.add(browseButton);
//The settings panel
Panel settingPanel = new Panel();
settingPanel.setLayout(new GridLayout(3, 6));
settingPanel.add(new Label("alpha"));
alphaField = new DoubleField("0.01");
settingPanel.add(alphaField);
settingPanel.add(new Label("beta"));
betaField = new DoubleField("0.02");
settingPanel.add(betaField);
settingPanel.add(new Label("gamma"));
gammaField = new DoubleField("0.03");
settingPanel.add(gammaField);
settingPanel.add(new Label("layer nodes"));
layerNodesField = new TextField("4, 8, 8, 3");
settingPanel.add(layerNodesField);
settingPanel.add(new Label("activators"));
activatorField = new TextField("sss");
settingPanel.add(activatorField);
settingPanel.add(new Label("training rounds"));
roundsField = new IntegerField("5000");
settingPanel.add(roundsField);
settingPanel.add(new Label("learning rate"));
learningRateField = new DoubleField("0.01");
settingPanel.add(learningRateField);
settingPanel.add(new Label("mobp"));
mobpField = new DoubleField("0.5");
settingPanel.add(mobpField);
Panel topPanel = new Panel();
topPanel.setLayout(new BorderLayout());
topPanel.add(BorderLayout.NORTH, sourceFilePanel);
topPanel.add(BorderLayout.CENTER, settingPanel);
messageTextArea = new TextArea(80, 40);
//Bottom: OK and Exit
Button okButton = new Button(" OK ");
okButton.addActionListener(this);
// DialogCloser dialogCloser = new DialogCloser(this);
Button exitButton = new Button(" Exit ");
// cancelButton.addActionListener(dialogCloser);
exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
Button helpButton = new Button(" Help ");
helpButton.setSize(20, 10);
helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
Panel okPanel = new Panel();
okPanel.add(okButton);
okPanel.add(exitButton);
okPanel.add(helpButton);
mainFrame.setLayout(new BorderLayout());
mainFrame.add(BorderLayout.NORTH, topPanel);
mainFrame.add(BorderLayout.CENTER, messageTextArea);
mainFrame.add(BorderLayout.SOUTH, okPanel);
mainFrame.setSize(600, 500);
mainFrame.setLocation(100, 100);
mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
mainFrame.setBackground(GUICommon.MY_COLOR);
mainFrame.setVisible(true);
}
/**
 * @Description: Read the arff file, then train and test the network
 * @Param: [ae]
 * @return: void
 */
public void actionPerformed(ActionEvent ae) {
String tempFilename = arffFilenameField.getText();
// Read the layers nodes.
String tempString = layerNodesField.getText().trim();
int[] tempLayerNodes = null;
try {
tempLayerNodes = stringToIntArray(tempString);
} catch (Exception ee) {
ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
return;
}
double tempLearningRate = learningRateField.getValue();
double tempMobp = mobpField.getValue();
String tempActivators = activatorField.getText().trim();
FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
tempActivators);
int tempRounds = roundsField.getValue();
long tempStartTime = new Date().getTime();
for (int i = 0; i < tempRounds; i++) {
tempNetwork.train();
}
long tempEndTime = new Date().getTime();
messageTextArea.append("\r\nSummary:\r\n");
messageTextArea.append("Trainng time: " + (tempEndTime - tempStartTime) + "ms.\r\n");
double tempAccuray = tempNetwork.test();
messageTextArea.append("Accuracy: " + tempAccuray + "\r\n");
messageTextArea.append("End.");
}
/**
 * @Description: Convert a comma-separated string into an int array
 * @Param: [paraString]
 * @return: int[]
 */
public static int[] stringToIntArray(String paraString) throws Exception {
int tempCounter = 1;
for (int i = 0; i < paraString.length(); i++) {
if (paraString.charAt(i) == ',') {
tempCounter++;
}
}
int[] resultArray = new int[tempCounter];
String tempRemainingString = paraString + ",";
String tempString;
for (int i = 0; i < tempCounter; i++) {
tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
if (tempString.equals("")) {
throw new Exception("Blank is unsupported");
}
resultArray[i] = Integer.parseInt(tempString);
tempRemainingString = tempRemainingString
.substring(tempRemainingString.indexOf(",") + 1);
}
return resultArray;
}
public static void main(String args[]) {
new AnnMain();
}// Of main
}
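The hand-rolled parser in stringToIntArray can also be written with String.split; a minimal sketch of the alternative (blank handling differs slightly, so this is only an illustration):

// Alternative to stringToIntArray based on String.split.
public static int[] stringToIntArrayViaSplit(String paraString) {
    String[] tempParts = paraString.split(",");
    int[] resultArray = new int[tempParts.length];
    for (int i = 0; i < tempParts.length; i++) {
        resultArray[i] = Integer.parseInt(tempParts[i].trim()); // throws on blank parts
    }
    return resultArray;
}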
Day 80: GUI (4. The Various Listener Mechanisms)
- Analyze, from the perspective of the listener mechanisms and interfaces, which code is triggered by each operation on the GUI;
Having written WinForm programs in C# before, I am fairly familiar with GUI event handling and listener mechanisms. The observer design pattern is used here: after an event source registers an event listener, whenever some action happens on the event source, the source calls a method of the listener and passes in an event object; the developer can then use the event object to operate on the event source. For example, when the mouse clicks some component, the click event fires and carries information such as whether a click happened and where the mouse is on the window.
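A generic sketch of this registration-and-callback flow (the names are mine; AWT implements the same idea with addActionListener and actionPerformed):

import java.util.ArrayList;
import java.util.List;

// Observer pattern behind AWT listeners: the source notifies registered observers.
public class ListenerDemo {
    interface ClickListener {           // the "event listener" interface
        void clicked(String paraWhere); // the callback, carrying event information
    }

    static class ButtonLike {           // the "event source"
        private List<ClickListener> listeners = new ArrayList<>();

        void addClickListener(ClickListener paraListener) { // registration
            listeners.add(paraListener);
        }

        void simulateClick(String paraWhere) { // the source calls every listener
            for (ClickListener tempListener : listeners) {
                tempListener.clicked(paraWhere);
            }
        }
    }

    public static void main(String[] args) {
        ButtonLike tempButton = new ButtonLike();
        tempButton.addClickListener(where -> System.out.println("Clicked at " + where));
        tempButton.simulateClick("(100, 50)");
    }
}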
- Summarize basic artificial neural networks.
It is an iterative algorithm: set the parameters to random initial values, compute the output of the current network, back-propagate the error between the current output and the sample's decision label, adjust the parameter values, and repeat until the error converges below some threshold.
Drawbacks:
- You do not know how the neural network will produce its result, let alone why it produces that result.
- It is rather time-consuming;
- It is hard to find large amounts of labeled data;