Java学习之日撸代码300行(71-80天,BP 神经网络)

本文详细介绍了BP神经网络的基础知识,包括网络结构、激活函数、训练与测试过程。此外,还展示了如何在Java中实现通用的BP神经网络,包括单层网络和全连接网络,并探讨了不同的激活函数。最后,文章提到了GUI界面的构建,包括数据读取控件和整体布局,以增强用户交互体验。
摘要由CSDN通过智能技术生成

原博文:minfanphd

第71天:BP神经网络基础类 (数据读取与基本结构)

以下资料来自于 《神经网络与深度学习》-邱锡鹏 一书

在这里插入图片描述

每一层的神经元可以接收前一层神经元的信号,并产生信号输出到下一层。第0 层称为输入层,最后一层称
为输出层,其他中间层称为隐藏层。

z ( l ) = W ( l ) a ( l − 1 ) + b ( l ) , \begin{aligned} z^{(l)} = W^{(l)}a^{(l-1)}+b^{(l)}, \end{aligned} z(l)=W(l)a(l1)+b(l),
a ( l ) = f l ( z ( l ) ) . \begin{aligned} a^{(l)} = f_l(z^{(l)}). \end{aligned} a(l)=fl(z(l)).

z ( l ) z^{(l)} z(l) 表示第 l l l 层的净输入,也就是值没有经过激活函数的输入。
a ( l ) a^{(l)} a(l) 则是指的经过激活函数后的输出。
W , b W,b W,b表示网络中所有层的连接权重和偏置。

package MachineLearning.ann;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

/**
 * @description:抽象类Ann
 * @learner: Qing Zhang
 * @time: 07
 */
public abstract class GeneralAnn {
    //数据集
    Instances dataset;

    //层的数量
    int numLayers;

    //每层的节点数量,如[3, 4, 6, 2]表示输入层有三个节点,隐藏层分别有4个和6个节点,输出层有两个节点,二分类
    int[] layerNumNodes;

    //动量系数(Momentum coefficient)
    public double mobp;

    //学习率
    public double learningRate;

    //随机种子
    Random random = new Random();

    /** 
    * @Description: 构造函数
    * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
    * @return: 
    */
    public GeneralAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
        try{
            FileReader tempReader = new FileReader(paraFileName);
            dataset = new Instances(tempReader);
            dataset.setClassIndex(dataset.numAttributes()-1);
            tempReader.close();
        }catch (Exception ee){
            System.out.println("Error occurred while trying to read \'" + paraFileName
                    + "\' in GeneralAnn constructor.\r\n" + ee);
            System.exit(0);
        }
        
        //接收参数
        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        learningRate = paraLearningRate;
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        mobp = paraMobp;
    }

    /**
    * @Description: 前向预测
    * @Param: [paraInput]
    * @return: double[]
    */
    public abstract double[] forward(double[] paraInput);



    /**
    * @Description: 反向传播
    * @Param: [paraTarget]
    * @return: void
    */
    public abstract void backPropagation(double[] paraTarget);


    /**
    * @Description: 训练
    * @Param: []
    * @return: void
    */
    public void train(){
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            //填充数据
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            }

            //填充类标签
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            //使用该实例训练
            forward(tempInput);
            backPropagation(tempTarget);
        }
    }

    /**
     * @Description: 获取数组的最大值对应的索引
     * @Param: [paraArray]
     * @return: int
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        double tempMax = -1e10;
        for (int i = 0; i < paraArray.length; i++) {
            if (tempMax < paraArray[i]) {
                tempMax = paraArray[i];
                resultIndex = i;
            }
        }

        return resultIndex;
    }

    /**
     * @Description: 使用数据集测试
     * @Param: []
     * @return: double
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];

        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            //填充数据
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            }

            //使用该实例训练
            tempPrediction = forward(tempInput);
            System.out.println("prediction: " + Arrays.toString(tempPrediction));
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            }
        }

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

        return tempNumCorrect / dataset.numInstances();
    }
}

第72天:固定激活函数的BP神经网络 (1. 网络结构理解)

  1. layerNumNodes 表示网络基本结构. 如: [3, 4, 6, 2] 表示:
    a) 输入端口有 3 个,即数据有 3 个条件属性. 如果与实际数据不符, 代码会自动纠正, 见 GeneralAnn.java 54 行.
    b) 输出端口有 2 个, 即数据的决策类别数为 2. 如果与实际数据不符, 代码会自动纠正, 见 GeneralAnn.java 55 行. 对于分类问题, 数据是哪个类别, 对应于输出值最大的端口.
    c) 有两个中间层(也就是隐藏层), 分别为 4 个和 6 个节点.
    纠正的原因主要还是需要跟数据集一致,毕竟这里的参数是人为设置,那么可能会出现错误,因此根据数据集的实际情况做纠正会更加严谨。
  2. layerNodeValues 表示各网络节点的值. 如上例, 网络的节点有 4 层, 即 layerNodeValues.length 为 4. 总结点数为 3 + 4 + 6 + 2 = 15 \mathbf{3 + 4 + 6 + 2 = 15} 3+4+6+2=15 个, 即 layerNodeValues[0].length = 3, layerNodeValues[1].length = 4, layerNodeValues[2].length = 6, layerNodeValues[3].length = 2. Java 支持这种不规则的矩阵 (不同行的列数不同), 因为二维矩阵被当作一维向量的一维向量.
  3. layerNodeErrors 表示各网络节点上的误差. 该数组大小于 layerNodeValues 一致.
  4. edgeWeights 表示各条边的权重. 由于两层之间的边为多对多关系 (二维数组), 多个层的边就成了三维数组. 例如, 上面例子的第 0 层就应该有 ( 3 + 1 ) × 4 = 16 \mathbf{( 3 + 1 ) \times 4 = 16} (3+1)×4=16 条边, 这里 + 1 \mathbf{+1} +1 表示有偏移量 offset. 总共的层数为 4 − 1 = 3 \mathbf{4 − 1 = 3} 41=3 , 即边的层数要比节点层数少 1. 这也是写程序过程中非常容易出错的地方.
  5. edgeWeightsDelta 与 edgeWeights 具有相同大小, 它辅助后者进行调整.

这里需要了解一下相应的优化函数,目前使用的是 momentum 动量法,具体的思想可以移步至这篇帖子
深度学习优化函数详解(4)-- momentum 动量法

下面是核心代码:

package MachineLearning.ann;

import weka.core.Instances;

import java.io.FileReader;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class SimpleAnn extends GeneralAnn {


    //前向传播过程中每个节点变化的值。第一维表示层,第二维表示节点
    public double[][] layerNodeValues;

    //反向传播过程中每个节点变化的错误。第一维表示层,第二维表示节点
    public double[][] layerNodeErrors;

    //边的权值。第一维表示层,第二维表示该层的节点下标,第三维表示下一层的节点下标
    public double[][][] edgeWeights;

    //边的权值变化值。它的大小与edgeWeights相同
    public double[][][] edgeWeightsDelta;

    /**
     * @Description: 构造函数
     * @Param: [paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp]
     * @return:
     */
    public SimpleAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp) {
        super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);

        //层层初始化
        layerNodeValues = new double[numLayers][];
        layerNodeErrors = new double[numLayers][];
        edgeWeights = new double[numLayers - 1][][];
        edgeWeightsDelta = new double[numLayers - 1][][];

        //层内初始化
        for (int l = 0; l < numLayers; l++) {
            layerNodeValues[l] = new double[layerNumNodes[l]];
            layerNodeErrors[l] = new double[layerNumNodes[l]];

            //后面初始化边时需要少一层,因为每条边穿过两层
            if (l + 1 == numLayers) {
                break;
            }

            //在 layerNumNodes[l] + 1,最后一个为偏移保留。
            edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
            edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];

            for (int i = 0; i < layerNumNodes[l] + 1; i++) {
                for (int j = 0; j < layerNumNodes[l + 1]; j++) {
                    //初始化权值
                    edgeWeights[l][i][j] = random.nextDouble();
                }
            }
        }
    }


    @Override
    public double[] forward(double[] paraInput) {
        //初始化输入层
        for (int i = 0; i < layerNodeValues[0].length; i++) {
            layerNodeValues[0][i] = paraInput[i];
        }

        //计算每层的节点值
        double z;
        for (int l = 1; l < numLayers; l++) {
            for (int j = 0; j < layerNodeValues[l].length; j++) {
                //根据偏置初始化,偏置为 +1
                //这里是先加上偏置
                z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
                //将所有边的加权和给该节点使用
                for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
                    z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
                }

                //Sigmoid 激活函数
                //对于其他激活函数,这一行应该更改。
                layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
            }
        }
        return layerNodeValues[numLayers - 1];
    }

    @Override
    public void backPropagation(double[] paraTarget) {
        //初始化输出层错误
        int l = numLayers - 1;
        for (int j = 0; j < layerNodeErrors[l].length; j++) {
            layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * (paraTarget[j] - layerNodeValues[l][j]);
        }

        //反向传播直到 l==0
        while (l > 0) {
            l--;
            //第l层的每个节点
            for (int j = 0; j < layerNumNodes[l]; j++) {
                double z = 0.0;
                //下一层的每个节点
                for (int i = 0; i < layerNumNodes[l + 1]; i++) {
                    if (l > 0) {
                        z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
                    }

                    //调整权值
                    edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i] + learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
                    edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];

                    if (j == layerNumNodes[l] - 1) {
                        //调整偏移部分的权值
                        edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
                                + learningRate * layerNodeErrors[l + 1][i];
                        edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
                    }
                }

                //根据Sigmoid的微分记录错误。
                //对于其他激活函数,这一行应该更改。
                layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
            }
        }
    }

    public static void main(String[] args) {
        int[] tempLayerNodes = {4, 8, 8, 3};
        BPNeuralNetwork tempNetwork = new BPNeuralNetwork("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01,
                0.6);

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        }

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
    }
}

第73天:固定激活函数的BP神经网络 (2. 训练与测试过程理解)

  1. Forward 就是利用当前网络对一条数据进行预测的过程.
  2. BackPropagation 就是根据误差进行网络权重调节的过程.
  3. 训练的时候需要前向与后向, 测试的时候只需要前向.
  4. 这里只实现了 sigmoid 激活函数, 反向传播时的导数与正向传播时的激活函数相对应. 如果要换激活函数, 需要两个地方同时换.(这里需要重点去理解一下,因为后向结合了优化函数,因此需要根据相应的优化函数以及激活函数去调整代码)

第74天:通用BP神经网络 (1. 集中管理激活函数)

  1. 激活与求导是一个, 前者用于 forward, 后者用于 back-propagation.
  2. 有很多的激活函数, 它们的设计有相应准则, 如分段可导.
  3. 查资料补充几个未实现的激活函数.
  4. 进一步测试.

Sigmoid:
σ ( x ) = 1 1 + e − x \sigma(x) = \frac{1}{1+e^{-x}} σ(x)=1+ex1
在这里插入图片描述


Tanh:
σ ( x ) = 2 1 + e ( − 2 x ) − 1 \sigma(x) = \frac{2}{1+e^{(-2x)}}-1 σ(x)=1+e(2x)21
在这里插入图片描述


Arctan:
σ ( x ) = arctan ⁡ ( x ) \sigma(x) = \arctan(x) σ(x)=arctan(x)
在这里插入图片描述


Elu:
σ ( x ) = { x , x ≥ 0 α ( e x − 1 ) , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ \alpha(e^x-1), x<0 \end{cases} σ(x)={x,x0α(ex1),x<0
在这里插入图片描述


Identity:
σ ( x ) = x \sigma(x) = x σ(x)=x
在这里插入图片描述


Soft Sign:
σ ( x ) = { x 1 + x , x ≥ 0 x 1 − x , x < 0 \sigma(x) = \begin{cases} \frac{x}{1+x},x\geq 0\\ \frac{x}{1-x}, x<0 \end{cases} σ(x)={1+xx,x01xx,x<0
在这里插入图片描述


Soft Plus:
σ ( x ) = log ⁡ ( 1 + e x ) \sigma(x) = \log(1+e^x) σ(x)=log(1+ex)
在这里插入图片描述


Relu:
σ ( x ) = { x , x ≥ 0 0 , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ 0, x<0 \end{cases} σ(x)={x,x00,x<0
在这里插入图片描述


Leaky Relu:
σ ( x ) = { x , x ≥ 0 α x , x < 0 \sigma(x) = \begin{cases} x,x\geq 0\\ \alpha x, x<0 \end{cases} σ(x)={x,x0αx,x<0

图像源码:

from matplotlib import pyplot as plt
import numpy as np
import math


def sigmoid_function(x):
    fz = []
    for num in x:
        fz.append(1 / (1 + math.exp(-num)))
    return fz


def sigmoid_test():
    x = np.arange(-10, 10, 0.01)
    fz = sigmoid_function(x)
    show_graph('Sigmoid Function', 'x', 'σ(x)', x, fz)


def tanh_function(x):
    fz = []
    for num in x:
        fz.append(2 / (1 + math.exp(-2 * num)) - 1)
    return fz

def tanh_test():
    x = np.arange(-10, 10, 0.01)
    fz = tanh_function(x)
    show_graph('Tanh Function', 'x', 'σ(x)', x, fz)


def arctan_function(x):
    fz = []
    for num in x:
        fz.append(math.atan(num))
    return fz


def arctan_test():
    x = np.arange(-50, 50, 0.01)
    fz = arctan_function(x)
    show_graph('Arctan Function', 'x', 'σ(x)', x, fz)


def elu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * (math.exp(num) - 1))
    return fz


def elu_test():
    x = np.arange(-50, 50, 0.01)
    fz = elu_function(x, 0.5)
    show_graph('Elu Function', 'x', 'σ(x)', x, fz)


def identity_function(x):
    fz = []
    for num in x:
        fz.append(num)
    return fz


def identity_test():
    x = np.arange(-10, 10, 0.01)
    fz = identity_function(x)
    show_graph('Identity Function', 'x', 'σ(x)', x, fz)


def leakyRelu_function(x, alpha):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(alpha * num)
    return fz


def leakyRelu_test():
    x = np.arange(-10, 10, 0.01)
    alpha = 0.5
    fz = leakyRelu_function(x, alpha)
    show_graph('Leaky Relu Function', 'x', 'σ(x)', x, fz)


def softSign_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num / (1 + num))
        else:
            fz.append(num / (1 - num))
    return fz


def softSign_test():
    x = np.arange(-10, 10, 0.01)
    fz = softSign_function(x)
    show_graph('Soft Sign Function', 'x', 'σ(x)', x, fz)


def softPlus_function(x):
    fz = []
    for num in x:
        fz.append(math.log(1 + math.exp(num)))
    return fz


def softPlus_test():
    x = np.arange(-10, 10, 0.01)
    fz = softPlus_function(x)
    show_graph('Soft Plus Function', 'x', 'σ(x)', x, fz)


def relu_function(x):
    fz = []
    for num in x:
        if num >= 0:
            fz.append(num)
        else:
            fz.append(0)
    return fz


def relu_test():
    x = np.arange(-10, 10, 0.01)
    fz = relu_function(x)
    show_graph('Relu Function', 'x', 'σ(x)', x, fz)


def show_graph(title, xlable, ylable, x, fz):
    plt.title(title)
    plt.xlabel(xlable)
    plt.ylabel(ylable)
    plt.plot(x, fz)
    plt.show()


if __name__ == '__main__':
    sigmoid_test()
    tanh_test()
    arctan_test()
    elu_test()
    identity_test()
    softSign_test()
    softPlus_test()
    relu_test()
    leakyRelu_test()


package MachineLearning.ann;

/**
 * @description:激活函数
 * @learner: Qing Zhang
 * @time: 07
 */
public class Activator {
    // Arc tan.
    public final char ARC_TAN = 'a';

    // Elu.
    public final char ELU = 'e';

    // Gelu.
    public final char GELU = 'g';

    // Hard logistic.
    public final char HARD_LOGISTIC = 'h';

    // Identity.
    public final char IDENTITY = 'i';

    // Leaky relu, also known as parametric relu.
    public final char LEAKY_RELU = 'l';

    // Relu.
    public final char RELU = 'r';

    // Soft sign.
    public final char SOFT_SIGN = 'o';

    // Sigmoid.
    public final char SIGMOID = 's';

    // Tanh.
    public final char TANH = 't';

    // Soft plus.
    public final char SOFT_PLUS = 'u';

    // Swish.
    public final char SWISH = 'w';

    // The activator.
    private char activator;

    // Alpha for elu.
    double alpha;

    // Beta for leaky relu.
    double beta;

    // Gamma for leaky relu.
    double gamma;

    /**
    * @Description: 构造函数
    * @Param: [paraActivator]
    * @return:
    */
    public Activator(char paraActivator) {
        activator = paraActivator;
    }

    /**
    * @Description: 设置
    * @Param: [paraActivator]
    * @return: void
    */
    public void setActivator(char paraActivator) {
        activator = paraActivator;
    }

    /**
    * @Description: 获取
    * @Param: []
    * @return: char
    */
    public char getActivator() {
        return activator;
    }

    /**
    * @Description: 设置α
    * @Param: [paraAlpha]
    * @return: void
    */
    void setAlpha(double paraAlpha) {
        alpha = paraAlpha;
    }// Of setAlpha

    /**
    * @Description: 设置β
    * @Param: [paraBeta]
    * @return: void
    */
    void setBeta(double paraBeta) {
        beta = paraBeta;
    }

    /**
    * @Description: 设置γ
    * @Param: [paraGamma]
    * @return: void
    */
    void setGamma(double paraGamma) {
        gamma = paraGamma;
    }

    /**
    * @Description: 根据设置的激活函数激活
    * @Param: [paraValue]
    * @return: double
    */
    public double activate(double paraValue) {
        double resultValue = 0;
        switch (activator) {
            case ARC_TAN:
                resultValue = Math.atan(paraValue);
                break;
            case ELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = alpha * (Math.exp(paraValue) - 1);
                }
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = paraValue;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = alpha * paraValue;
                }
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
                    resultValue = paraValue / (1 + paraValue);
                } else {
                    resultValue = paraValue / (1 - paraValue);
                }
                break;
            case SOFT_PLUS:
                resultValue = Math.log(1 + Math.exp(paraValue));
                break;
            case RELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = 0;
                }
                break;
            case SIGMOID:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case TANH:
                resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }

        return resultValue;
    }

    /**
    * @Description: 根据激活函数求导。有些使用x,有些使用f(x)
    * @Param: [paraValue:x, paraActivatedValue:f(x)]
    * @return: double
    */
    public double derive(double paraValue, double paraActivatedValue) {
        double resultValue = 0;
        switch (activator) {
            case ARC_TAN:
                resultValue = 1 / (paraValue * paraValue + 1);
                break;
            case ELU:
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
                }
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = 1;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = alpha;
                }
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
                    resultValue = 1 / (1 + paraValue) / (1 + paraValue);
                } else {
                    resultValue = 1 / (1 - paraValue) / (1 - paraValue);
                }
                break;
            case SOFT_PLUS:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case RELU: // Updated
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = 0;
                }
                break;
            case SIGMOID: // Updated
                resultValue = paraActivatedValue * (1 - paraActivatedValue);
                break;
            case TANH: // Updated
                resultValue = 1 - paraActivatedValue * paraActivatedValue;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }

        return resultValue;
    }


    public String toString() {
        String resultString = "Activator with function '" + activator + "'";
        resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;

        return resultString;
    }


    public static void main(String[] args) {
        Activator tempActivator = new Activator('s');
        double tempValue = 0.6;
        double tempNewValue;
        tempNewValue = tempActivator.activate(tempValue);
        System.out.println("After activation: " + tempNewValue);

        tempNewValue = tempActivator.derive(tempValue, tempNewValue);
        System.out.println("After derive: " + tempNewValue);
    }

}


在这里插入图片描述

第75天:通用BP神经网络 (2. 单层实现)

  1. 仅实现单层 ANN.
  2. 可以有自己的激活函数.
  3. 正向计算输出, 反向计算误差并调整权值.

这里对单层的ANN进行了编码,同时进行测试,可以结合之前创建的 Activator 类调整激活函数。

package MachineLearning.ann;

import java.util.Arrays;
import java.util.Random;

/**
 * @description: Ann层
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnLayer {

    //输入数量
    int numInput;

    //输出数量
    int numOutput;

    //学习率
    double learningRate;

    //动量系数
    double mobp;

    //权值矩阵
    double[][] weights, deltaWeights;

    double[] offset, deltaOffset, errors;

    //输入
    double[] input;

    //输出
    double[] output;

    //激活后的输出
    double[] activatedOutput;

    //输入
    Activator activator;

    //输入
    Random random = new Random();

    public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator, double paraLearningRate, double paraMobp) {
        numInput = paraNumInput;
        numOutput = paraNumOutput;
        learningRate = paraLearningRate;
        mobp = paraMobp;

        weights = new double[numInput + 1][numOutput];
        deltaWeights = new double[numInput + 1][numOutput];
        for (int i = 0; i < numInput + 1; i++) {
            for (int j = 0; j < numOutput; j++) {
                weights[i][j] = random.nextDouble();
            }
        }

        offset = new double[numOutput];
        deltaOffset = new double[numOutput];
        errors = new double[numInput];

        input = new double[numInput];
        output = new double[numOutput];
        activatedOutput = new double[numOutput];

        activator = new Activator(paraActivator);
    }

    /**
     * @Description: 前向预测
     * @Param: [paraInput]
     * @return: double[]
     */
    public double[] forward(double[] paraInput) {
        //拷贝数据
        for (int i = 0; i < numInput; i++) {
            input[i] = paraInput[i];
        }

        //计算加权和以求得每个输出
        for (int i = 0; i < numOutput; i++) {
            output[i] = weights[numInput][i];
            for (int j = 0; j < numInput; j++) {
                output[i] += input[j] * weights[j][i];
            }

            activatedOutput[i] = activator.activate(output[i]);
        }

        return activatedOutput;
    }


    /**
     * @Description: 反向传播并改变权值
     * @Param: [paraInput]
     * @return: double[]
     */
    public double[] backPropagation(double[] paraErrors) {
        //拷贝数据
        for (int i = 0; i < paraErrors.length; i++) {
            paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
        }

        //计算当前的错误
        for (int i = 0; i < numInput; i++) {
            errors[i] = 0;
            for (int j = 0; j < numOutput; j++) {
                errors[i] += paraErrors[j] * weights[i][j];
                deltaWeights[i][j] = mobp * deltaWeights[i][j] + learningRate * paraErrors[j] * input[i];
                weights[i][j] += deltaWeights[i][j];

                if (i == numInput - 1) {
                    //调整偏置
                    deltaOffset[j] = mobp * deltaOffset[j] + learningRate * paraErrors[j];
                    offset[j] += deltaOffset[j];
                }
            }
        }

        return errors;
    }

    /**
     * @Description: 获取最后一层的错误
     * @Param: [paraTarget]
     * @return: double[]
     */
    public double[] getLastLayerErrors(double[] paraTarget) {
        double[] resultErrors = new double[numOutput];
        for (int i = 0; i < numOutput; i++) {
            resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
        }

        return resultErrors;
    }

    @Override
    public String toString() {
        String resultString = "";
        resultString += "Activator: " + activator;
        resultString += "\r\n weights = " + Arrays.deepToString(weights);
        return resultString;
    }

    /**
     * @Description: 单元测试
     * @Param: []
     * @return: void
     */
    public static void unitTest() {
        AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
        double[] tempInput = {1, 4};

        System.out.println(tempLayer);

        double[] tempOutput = tempLayer.forward(tempInput);
        System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

        double[] tempError = tempLayer.backPropagation(tempOutput);
        System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
    }

    public static void main(String[] args) {
        unitTest();
    }
}

在这里插入图片描述

第76天:通用BP神经网络 (3. 综合测试)

  1. 自己尝试其它的激活函数.
package MachineLearning.ann;

/**
 * @description: 完整的神经网络
 * @learner: Qing Zhang
 * @time: 07
 */
public class FullAnn extends GeneralAnn {

    AnnLayer[] layers;


    public FullAnn(String paraFileName, int[] paraLayerNumNodes, double paraLearningRate, double paraMobp, String paraActivators) {
        super(paraFileName, paraLayerNumNodes, paraLearningRate, paraMobp);

        //初始化层
        layers = new AnnLayer[numLayers - 1];
        for (int i = 0; i < layers.length; i++) {
            layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate, paraMobp);
        }
    }

    @Override
    public double[] forward(double[] paraInput) {
        double[] resultArray = paraInput;
        for (int i = 0; i < numLayers - 1; i++) {
            resultArray = layers[i].forward(resultArray);
        }
        return resultArray;
    }

    @Override
    public void backPropagation(double[] paraTarget) {
        double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
        for (int i = numLayers - 2; i >= 0; i--) {
            tempErrors = layers[i].backPropagation(tempErrors);
        }
    }

    @Override
    public String toString() {
        String resultString = "I am a full ANN with " + numLayers + " layers";
        return resultString;
    }

    public static void main(String[] args) {
        int[] tempLayerNodes = {4, 8, 8, 3};
        FullAnn tempNetwork = new FullAnn("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff", tempLayerNodes, 0.01, 0.6, "sss");

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        }

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
        System.out.println("FullAnn ends.");
    }
}

Sigmoid函数:
在这里插入图片描述
SOFT_SIGN:
在这里插入图片描述

SOFT_PLUS:
在这里插入图片描述

RELU:
在这里插入图片描述
LEAKY_RELU:

在这里插入图片描述
ELU:
在这里插入图片描述

第77天:GUI (1. 对话框相关控件)

  1. ApplicationShowdown.java 仅用于退出图形用户界面 GUI.
  2. 只生成了一个静态的实例对象. 构造方法是 private 的, 不允许在该类之外 new. 这是一个有意思的小技巧.
package MachineLearning.gui;

import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowEvent;
import java.awt.event.WindowListener;

/**
 * @description:通过窗口事件或者按钮事件关闭应用程序
 * @learner: Qing Zhang
 * @time: 07
 */
public class ApplicationShutdown implements WindowListener, ActionListener {

    //只能存在一个对象
    public static ApplicationShutdown applicationShutdown = new ApplicationShutdown();


    //构造函数是私人的,因为只能存在一个对象,而静态对象已经声明了。
    private ApplicationShutdown() {
    }


    //关闭系统
    public void windowClosing(WindowEvent comeInWindowEvent) {
        System.exit(0);
    }// Of windowClosing.

    public void windowActivated(WindowEvent comeInWindowEvent) {
    }

    public void windowClosed(WindowEvent comeInWindowEvent) {
    }

    public void windowDeactivated(WindowEvent comeInWindowEvent) {
    }

    public void windowDeiconified(WindowEvent comeInWindowEvent) {
    }

    public void windowIconified(WindowEvent comeInWindowEvent) {
    }

    public void windowOpened(WindowEvent comeInWindowEvent) {
    }


    public void actionPerformed(ActionEvent ee) {
        System.exit(0);
    }

}

DialogCloser.java 用于关闭窗口, 而不是整个的 GUI.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;

/**
 * @description:关闭当前窗口
 * @learner: Qing Zhang
 * @time: 07
 */
public class DialogCloser extends WindowAdapter implements ActionListener {

    //当前打开的窗口
    private Dialog currentDialog;


    public DialogCloser() {
        super();
    }


    public DialogCloser(Dialog paraDialog) {
        currentDialog = paraDialog;
    }// Of the second constructor

    /**
    * @Description: 关闭窗口
     * 点击窗口右上角时
    * @Param: [paraWindowEvent]
    * @return: void
    */
    public void windowClosing(WindowEvent paraWindowEvent) {
        paraWindowEvent.getWindow().dispose();
    }

    /**
     ***************************
     * Close the dialog while pushing an "OK" or "Cancel" button.
     *
     * @param paraEvent
     *            Not considered.
     ***************************
     */
    /**
    * @Description: 关闭窗口
     * 当点击“OK”或者“Cancel”按钮时
    * @Param: [paraEvent]
    * @return: void
    */
    public void actionPerformed(ActionEvent paraEvent) {
        currentDialog.dispose();
    }
}

ErrorDialog.java 用于显示出错信息. 有了 GUI 我们可以不再使用 System.out.println.

package MachineLearning.gui;

import java.awt.*;

/**
 * @description:错误窗口
 * @learner: Qing Zhang
 * @time: 07
 */
public class ErrorDialog extends Dialog {

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 124535235L;

    //唯一的错误窗口
    public static ErrorDialog errorDialog = new ErrorDialog();


    //用于显示信息的标签文本
    private TextArea messageTextArea;


    /**
    * @Description: 错误窗口
     * 该窗口与其他窗口一样也只存在一个,这样可以节省内存,
     * 当出现许多错误时,一个错误窗口即可解决
    * @Param: []
    * @return:
    */
    private ErrorDialog() {
        //模型窗口
        super(GUICommon.mainFrame, "Error", true);

        //初始化该窗口的内容
        messageTextArea = new TextArea();

        Button okButton = new Button("OK");
        okButton.setSize(20, 10);
        okButton.addActionListener(new DialogCloser(this));
        Panel okPanel = new Panel();
        okPanel.setLayout(new FlowLayout());
        okPanel.add(okButton);

        //添加文本域和按钮
        setLayout(new BorderLayout());
        add(BorderLayout.CENTER, messageTextArea);
        add(BorderLayout.SOUTH, okPanel);

        setLocation(200, 200);
        setSize(500, 200);
        addWindowListener(new DialogCloser());
        setVisible(false);
    }

    /**
    * @Description: 设置信息
    * @Param: [paramMessage]
    * @return: void
    */
    public void setMessageAndShow(String paramMessage) {
        messageTextArea.setText(paramMessage);
        setVisible(true);
    }
}

GUICommon.java 存储一些公用变量.

package MachineLearning.gui;

import javax.swing.*;
import java.awt.*;

/**
 * @description:公共变量
 * @learner: Qing Zhang
 * @time: 07
 */
public class GUICommon extends Object {

    //仅一个主窗口
    public static Frame mainFrame = null;


    //一个主布局
    public static JTabbedPane mainPane = null;

    //默认数量
    public static int currentProjectNumber = 0;

    //默认文字
    public static final Font MY_FONT = new Font("Times New Romans", Font.PLAIN, 12);

    //默认颜色
    public static final Color MY_COLOR = Color.lightGray;


    /** 
    * @Description: 设置主窗口。这一步骤仅在开始时执行一次
    * @Param: [paraFrame]
    * @return: void
    */
    public static void setFrame(Frame paraFrame) throws Exception {
        if (mainFrame == null) {
            mainFrame = paraFrame;
        } else {
            throw new Exception("The main frame can be set only ONCE!");
        }
    }

    
    
    /** 
    * @Description: 设置主布局。这一步骤仅在开始时执行一次
    * @Param: [paramPane]
    * @return: void
    */
    public static void setPane(JTabbedPane paramPane) throws Exception {
        if (mainPane == null) {
            mainPane = paramPane;
        } else {
            throw new Exception("The main panel can be set only ONCE!");
        }
    }
}

HelpDialog.java 显示帮助信息, 这样, 在主界面点击 Help 按钮时, 就会显示相关参数的说明. 其目的在于提高软件的易用性、可维护性.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.IOException;
import java.io.RandomAccessFile;

/**
 * @description:帮助框
 * @learner: Qing Zhang
 * @time: 07
 */
public class HelpDialog extends Dialog implements ActionListener {
    /**
     * Serial uid. Not quite useful.
     */
    private static final long serialVersionUID = 3869415040299264995L;


    /**
    * @Description: 显示帮助窗口
    * @Param: [paraTitle, paraFilename]
    * @return:
    */
    public HelpDialog(String paraTitle, String paraFilename) {
        super(GUICommon.mainFrame, paraTitle, true);
        setBackground(GUICommon.MY_COLOR);

        TextArea displayArea = new TextArea("", 10, 10, TextArea.SCROLLBARS_VERTICAL_ONLY);
        displayArea.setEditable(false);
        String textToDisplay = "";
        try {
            RandomAccessFile helpFile = new RandomAccessFile(paraFilename, "r");
            String tempLine = helpFile.readLine();
            while (tempLine != null) {
                textToDisplay = textToDisplay + tempLine + "\n";
                tempLine = helpFile.readLine();
            }
            helpFile.close();
        } catch (IOException ee) {
            dispose();
            ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
        }
        // 如果需要显示中文就使用这个。用这个方法
        // method.
        // textToDisplay = SimpleTools.GB2312ToUNICODE(textToDisplay);
        displayArea.setText(textToDisplay);
        displayArea.setFont(new Font("Times New Romans", Font.PLAIN, 14));

        Button okButton = new Button("OK");
        okButton.setSize(20, 10);
        okButton.addActionListener(new DialogCloser(this));
        Panel okPanel = new Panel();
        okPanel.setLayout(new FlowLayout());
        okPanel.add(okButton);

        // OK 按钮
        setLayout(new BorderLayout());
        add(BorderLayout.CENTER, displayArea);
        add(BorderLayout.SOUTH, okPanel);

        setLocation(120, 70);
        setSize(500, 400);
        addWindowListener(new DialogCloser());
        setVisible(false);
    }
    
    /** 
    * @Description: 简单的激活使它可视化
    * @Param: [ee]
    * @return: void
    */
    public void actionPerformed(ActionEvent ee) {
        setVisible(true);
    }
}

第78天:GUI (2. 数据读取控件)

DoubleField.java 用于接受实型值, 如果不能解释成实型值会报错. 这样可以把用户的低级错误扼杀在摇篮中.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;

/**
 * @description:用于接收double值
 * @learner: Qing Zhang
 * @time: 07
 */
public class DoubleField extends TextField implements FocusListener {

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 363634723L;

    //值
    protected double doubleValue;

    //赋予默认值
    public DoubleField() {
        this("5.13", 10);
    }// Of the first constructor

    //只指定内容
    public DoubleField(String paraString) {
        this(paraString, 10);
    }// Of the second constructor

    //只指定宽
    public DoubleField(int paraWidth) {
        this("5.13", paraWidth);
    }// Of the third constructor

    /**
    * @Description: 指定宽和长
    * @Param: [paraString, paraWidth]
    * @return:
    */
    public DoubleField(String paraString, int paraWidth) {
        super(paraString, paraWidth);
        addFocusListener(this);
    }

    /**
    * @Description:获得焦点事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusGained(FocusEvent paraEvent) {
    }

    /**
    * @Description: 执行焦点的监听事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusLost(FocusEvent paraEvent) {
        try {
            doubleValue = Double.parseDouble(getText());
        } catch (Exception ee) {
            ErrorDialog.errorDialog
                    .setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
            requestFocus();
        }
    }

    /**
    * @Description: 获取值
    * @Param: []
    * @return: double
    */
    public double getValue() {
        try {
            doubleValue = Double.parseDouble(getText());
        } catch (Exception ee) {
            ErrorDialog.errorDialog
                    .setMessageAndShow("\"" + getText() + "\" Not a double. Please check.");
            requestFocus();
        } 
        return doubleValue;
    }
}

IntegerField.java 同理.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;

/**
 * @description: 用于接收int值
 * @learner: Qing Zhang
 * @time: 07
 */
public class IntegerField extends TextField implements FocusListener {

    //Serial uid. 不一定有用
    private static final long serialVersionUID = -2462338973265150779L;

    //只指定内容
    public IntegerField() {
        this("513");
    }// Of constructor

    /** 
    * @Description: 指定宽和长
    * @Param: [paraString, paraWidth]
    * @return: 
    */
    public IntegerField(String paraString, int paraWidth) {
        super(paraString, paraWidth);
        addFocusListener(this);
    }

    //只指定内容
    public IntegerField(String paraString) {
        super(paraString);
        addFocusListener(this);
    }

    //只指定宽
    public IntegerField(int paraWidth) {
        super(paraWidth);
        setText("513");
        addFocusListener(this);
    }

    /**
     * @Description:获得焦点事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusGained(FocusEvent paraEvent) {
    }

    /**
     * @Description: 执行焦点的监听事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusLost(FocusEvent paraEvent) {
        try {
            Integer.parseInt(getText());
            // System.out.println(tempInt);
        } catch (Exception ee) {
            ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
                    + "\"Not an integer. Please check.");
            requestFocus();
        }
    }

    /**
     * @Description: 获取值
     * @Param: []
     * @return: int
     */
    public int getValue() {
        int tempInt = 0;
        try {
            tempInt = Integer.parseInt(getText());
        } catch (Exception ee) {
            ErrorDialog.errorDialog.setMessageAndShow("\"" + getText()
                    + "\" Not an int. Please check.");
            requestFocus();
        }
        return tempInt;
    }
}

FilenameField.java 则需要借助于系统提供的 FileDialog.

package MachineLearning.gui;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.FocusEvent;
import java.awt.event.FocusListener;
import java.io.File;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class FilenameField extends TextField implements ActionListener, FocusListener {

    //Serial uid. 不一定有用
    private static final long serialVersionUID = 4572287941606065298L;

    /** 
    * @Description: 初始化
    * @Param: []
    * @return: 
    */
    public FilenameField() {
        super();
        setText("");
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraWidth]
    * @return: 
    */
    public FilenameField(int paraWidth) {
        super(paraWidth);
        setText("");
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraWidth, paraText]
    * @return: 
    */
    public FilenameField(int paraWidth, String paraText) {
        super(paraWidth);
        setText(paraText);
        addFocusListener(this);
    }

    /** 
    * @Description: 初始化
    * @Param: [paraText, paraWidth]
    * @return: 
    */
    public FilenameField(String paraText, int paraWidth) {
        super(paraWidth);
        setText(paraText);
        addFocusListener(this);
    }

    /** 
    * @Description: 避免null或者空串
    * @Param: [paraText]
    * @return: void
    */
    public void setText(String paraText) {
        if (paraText.trim().equals("")) {
            super.setText("unspecified");
        } else {
            super.setText(paraText.replace('\\', '/'));
        }
    }

    /** 
    * @Description: 执行活动监听
    * @Param: [paraEvent]
    * @return: void
    */
    public void actionPerformed(ActionEvent paraEvent) {
        FileDialog tempDialog = new FileDialog(GUICommon.mainFrame,
                "Select a file");
        tempDialog.setVisible(true);
        if (tempDialog.getDirectory() == null) {
            setText("");
            return;
        }

        String directoryName = tempDialog.getDirectory();

        String tempFilename = directoryName + tempDialog.getFile();
        //System.out.println("tempFilename = " + tempFilename);

        setText(tempFilename);
    }

    /** 
    * @Description: 执行焦点监听事件
    * @Param: [paraEvent]
    * @return: void
    */
    public void focusGained(FocusEvent paraEvent) {
    }

    /**
     * @Description: 执行焦点监听事件
     * @Param: [paraEvent]
     * @return: void
     */
    public void focusLost(FocusEvent paraEvent) {
        // System.out.println("Focus lost exists.");
        String tempString = getText();
        if ((tempString.equals("unspecified"))
                || (tempString.equals("")))
            return;
        File tempFile = new File(tempString);
        if (!tempFile.exists()) {
            ErrorDialog.errorDialog.setMessageAndShow("File \"" + tempString
                    + "\" not exists. Please check.");
            requestFocus();
            setText("");
        }
    }
}

第79天:GUI (3. 总体布局)

  1. 用了 GridLayout 和 BorderLayout 来组织控件.
  2. 按下 OK 执行 actionPerformed. 前两天已经有类似代码了.
package MachineLearning.gui;

import MachineLearning.ann.FullAnn;

import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Date;

/**
 * @description:
 * @learner: Qing Zhang
 * @time: 07
 */
public class AnnMain implements ActionListener {

    //选择arff文件
    private FilenameField arffFilenameField;


    //设置α
    private DoubleField alphaField;


    //设置β
    private DoubleField betaField;

    //设置γ
    private DoubleField gammaField;


    //每层节点,如 "4, 8, 8, 3".
    private TextField layerNodesField;


    //激活函数的选择,例如 "ssa".
    private TextField activatorField;

    //训练次数
    private IntegerField roundsField;

    //学习率
    private DoubleField learningRateField;

    //mobp
    private DoubleField mobpField;


    //信息区域
    private TextArea messageTextArea;

    /**
    * @Description: 唯一的构造函数
    * @Param: []
    * @return:
    */
    public AnnMain() {
        //一个简单的窗口包含对话框
        Frame mainFrame = new Frame();
        mainFrame.setTitle("ANN. minfanphd@163.com");
        //顶部:选择arff文件
        arffFilenameField = new FilenameField(30);
        arffFilenameField.setText("d:/data/iris.arff");
        Button browseButton = new Button(" Browse ");
        browseButton.addActionListener(arffFilenameField);

        Panel sourceFilePanel = new Panel();
        sourceFilePanel.add(new Label("The .arff file:"));
        sourceFilePanel.add(arffFilenameField);
        sourceFilePanel.add(browseButton);

        //设置面板
        Panel settingPanel = new Panel();
        settingPanel.setLayout(new GridLayout(3, 6));

        settingPanel.add(new Label("alpha"));
        alphaField = new DoubleField("0.01");
        settingPanel.add(alphaField);

        settingPanel.add(new Label("beta"));
        betaField = new DoubleField("0.02");
        settingPanel.add(betaField);

        settingPanel.add(new Label("gamma"));
        gammaField = new DoubleField("0.03");
        settingPanel.add(gammaField);

        settingPanel.add(new Label("layer nodes"));
        layerNodesField = new TextField("4, 8, 8, 3");
        settingPanel.add(layerNodesField);

        settingPanel.add(new Label("activators"));
        activatorField = new TextField("sss");
        settingPanel.add(activatorField);

        settingPanel.add(new Label("training rounds"));
        roundsField = new IntegerField("5000");
        settingPanel.add(roundsField);

        settingPanel.add(new Label("learning rate"));
        learningRateField = new DoubleField("0.01");
        settingPanel.add(learningRateField);

        settingPanel.add(new Label("mobp"));
        mobpField = new DoubleField("0.5");
        settingPanel.add(mobpField);

        Panel topPanel = new Panel();
        topPanel.setLayout(new BorderLayout());
        topPanel.add(BorderLayout.NORTH, sourceFilePanel);
        topPanel.add(BorderLayout.CENTER, settingPanel);

        messageTextArea = new TextArea(80, 40);

        //底部:ok和exit
        Button okButton = new Button(" OK ");
        okButton.addActionListener(this);
        // DialogCloser dialogCloser = new DialogCloser(this);
        Button exitButton = new Button(" Exit ");
        // cancelButton.addActionListener(dialogCloser);
        exitButton.addActionListener(ApplicationShutdown.applicationShutdown);
        Button helpButton = new Button(" Help ");
        helpButton.setSize(20, 10);
        helpButton.addActionListener(new HelpDialog("ANN", "src/machinelearning/gui/help.txt"));
        Panel okPanel = new Panel();
        okPanel.add(okButton);
        okPanel.add(exitButton);
        okPanel.add(helpButton);

        mainFrame.setLayout(new BorderLayout());
        mainFrame.add(BorderLayout.NORTH, topPanel);
        mainFrame.add(BorderLayout.CENTER, messageTextArea);
        mainFrame.add(BorderLayout.SOUTH, okPanel);

        mainFrame.setSize(600, 500);
        mainFrame.setLocation(100, 100);
        mainFrame.addWindowListener(ApplicationShutdown.applicationShutdown);
        mainFrame.setBackground(GUICommon.MY_COLOR);
        mainFrame.setVisible(true);
    }


    /**
    * @Description: 读入arff文件
    * @Param: [ae]
    * @return: void
    */
    public void actionPerformed(ActionEvent ae) {
        String tempFilename = arffFilenameField.getText();

        // Read the layers nodes.
        String tempString = layerNodesField.getText().trim();

        int[] tempLayerNodes = null;
        try {
            tempLayerNodes = stringToIntArray(tempString);
        } catch (Exception ee) {
            ErrorDialog.errorDialog.setMessageAndShow(ee.toString());
            return;
        }

        double tempLearningRate = learningRateField.getValue();
        double tempMobp = mobpField.getValue();
        String tempActivators = activatorField.getText().trim();
        FullAnn tempNetwork = new FullAnn(tempFilename, tempLayerNodes, tempLearningRate, tempMobp,
                tempActivators);
        int tempRounds = roundsField.getValue();

        long tempStartTime = new Date().getTime();
        for (int i = 0; i < tempRounds; i++) {
            tempNetwork.train();
        }
        long tempEndTime = new Date().getTime();
        messageTextArea.append("\r\nSummary:\r\n");
        messageTextArea.append("Trainng time: " + (tempEndTime - tempStartTime) + "ms.\r\n");

        double tempAccuray = tempNetwork.test();
        messageTextArea.append("Accuracy: " + tempAccuray + "\r\n");
        messageTextArea.append("End.");
    }

    /** 
    * @Description: 将带逗号的字符串转换为int数组。
    * @Param: [paraString]
    * @return: int[]
    */
    public static int[] stringToIntArray(String paraString) throws Exception {
        int tempCounter = 1;
        for (int i = 0; i < paraString.length(); i++) {
            if (paraString.charAt(i) == ',') {
                tempCounter++;
            }
        }

        int[] resultArray = new int[tempCounter];

        String tempRemainingString = new String(paraString) + ",";
        String tempString;
        for (int i = 0; i < tempCounter; i++) {
            tempString = tempRemainingString.substring(0, tempRemainingString.indexOf(",")).trim();
            if (tempString.equals("")) {
                throw new Exception("Blank is unsupported");
            }

            resultArray[i] = Integer.parseInt(tempString);

            tempRemainingString = tempRemainingString
                    .substring(tempRemainingString.indexOf(",") + 1);
        }

        return resultArray;
    }


    public static void main(String args[]) {
        new AnnMain();
    }// Of main
}

在这里插入图片描述

第80天:GUI (4. 各种监听机制)

  1. 从监听机制、接口等角度, 分析在 GUI 上的各种操作分别会触发哪些代码;

    由于之前用C#写过winform程序,所以对于GUI上的事件响应,监听机制还是比较熟悉的,这里主要是使用了观察者设计模式,事件源注册事件监听器后,当事件源上发生某个动作时,事件源就会调用事件监听的一个方法,并将事件对象传递进去,开发者可以利用事件对象操作事件源。比如当操作鼠标点击某个部件时,可以将鼠标的点击事件触发,从而传递消息,比如是否点击以及鼠标在窗体上的位置等信息。

  2. 总结基础的人工神经网络.

迭代算法,随机设定参数的初始值,计算当前网络的输出,根据当前输出与样本决策标签的误差再反向传播,改变参数值,不断循环往复直至收敛至某一阈值。

缺点:

  1. 不知道你的神经网络将会如何产出结果,更不知道为什么会产生这种结果。
  2. 比较耗时;
  3. 难以找到大量有标签的数据;
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值