BP神经网络

BP神经网络

时间:2022/8/2

1.神经元模型

在这里插入图片描述

在生物神经网络中,最主要的结构便是神经元,如上图所示,便是生物神经元的结构模型,树突感知其他神经元传递的信息,通过轴突向后传播。神经元与神经元之间传播时通过一个突触的结构,通过神经递质改变下一个神经元的电位,当电位超过一定阈值,则信息将通过电位在下一个神经元上进行传递。故我们可将神经元的这种结构特性进行提取,构建神经元模型。

2.M-P神经元模型

在这里插入图片描述

M-P神经元模型是1943年由[McCulloch and Pitts,1943]将生物神经元抽象出来的神经元模型,一直沿用至今。在这个模型中,神经元接收n个其他神经元传递的输入,模拟树突。这些输入信号通过带权重的连接进行传递,神经元接受到的总输入值将与神经元的阈值进行比较,然后通过神经元的“激活函数”(ativation function)的处理以产生神经元的输出。虽然说看似是对生物神经元的模拟,但我感觉并不是真的模拟的生物神经元,感觉还是差别挺大的,归根结底,还是一个数学模型。
激活函数常采用阶跃函数sigmoid函数进行处理。下图便是sigmoid函数的图像,该函数可以将较大的数值变化,转换成较小的区间**[0,1]**里,有时也称之为挤压函数。

在这里插入图片描述

3.感知机与多层网络

感知机(Perceptron由两层神经元构成,如下图所示,输入层接收外界输入信号,输出层则是M-P神经元。其输出函数为:

y = f ( ∑ i w i x i − θ ) (1) y=f(\sum_i{w_ix_i}-\theta) \tag{1} y=f(iwixiθ)(1)
在这里插入图片描述

感知机能轻松实现逻辑与、或、非的运算。

对于感知机的训练过程,一般地,给定训练数据集,权重 w i ( i = 1 , 2 , . . , n ) w_i(i=1,2,..,n) wi(i=1,2,..,n)以及阈值 θ \theta θ可通过学习得到。阈值 θ \theta θ可以看作是一个固定输入-1.0的”哑结点“所对应的连接权重为 w n + 1 w_{n+1} wn+1,这样,将权重学习与阈值学习统一为权重的学习。而感知机的学习规则非常简单,对于训练样本 ( x , y ) (x,y) (x,y),若当前感知机的输出为 y ^ \hat y y^,则感知机的权值调整:

w i ← w i + Δ w i , Δ w i = η ( y − y ^ ) x i (2) w_i\leftarrow w_i+\Delta w_i,\\ \Delta w_i=\eta(y-\hat y)x_i \tag{2} wiwi+Δwi,Δwi=η(yy^)xi(2)

欲解决非线性问题,则需要多层功能神经元。如下图的三层神经元网络,在输入层和输出层之间再加入一层神经元,这一层称之为**“隐含层(hidden layer)”** 。通过如图的多层网络,即可解决异或的问题。
在这里插入图片描述

更一般的神经网络则是如下图所示的**“多层前馈神经网络”。即每层神经元与下层神经元完全互连,即“全连接”**。同一层神经元不相连,而且不能跨层连接。
在这里插入图片描述

输入层神经元只是接受数据输入,不进行任何处理,而隐含层和输出层则包含功能神经元。神经网络的学习过程,就是根据训练数据来调整神经元之间的“连接权”和功能神经元的阈值的过程。

3.误差反向传播算法(BP)

对于多层学习,其学习规则比感知机更为复杂。 而误差反向传播(error BackPropagation)则是最为成功的神经网络算法。如图所示,便是BP算法的网络结构。
在这里插入图片描述

3.1 正向传播

对于训练集 ( X k , y k ) (X_k,y_k) Xk,yk,假定神经网络的输出为 y ^ = ( y ^ 1 k , y ^ 2 k , . . . , y ^ l k ) \hat y =(\hat y _1^k,\hat y _2^k,...,\hat y _l^k) y^=(y^1k,y^2k,...,y^lk).则,设 β j = ∑ w h j x i \beta_j=\sum w_{hj}x_i βj=whjxi为第j个输出神经元的输入, θ j \theta_j θj为第j个输出神经元的阈值, w h j w_{hj} whj为隐含层第h个神经元到第j个输出神经元的连接权,激活函数为sigmoid,其输出函数为:

y ^ j k = f ( β j − θ j ) (3) \hat y_j^k=f(\beta_j-\theta_j) \tag{3} y^jk=f(βjθj)(3)

则网络在该数据集 ( X k , y k ) (X_k,y_k) (Xk,yk)上的方差为:

E k = 1 2 ∑ j = 1 l ( y ^ j k − y j k ) 2 (4) E_k={1\over2}\sum_{j=1}^l(\hat y_j^k-y_j^k)^2\tag{4} Ek=21j=1l(y^jkyjk)2(4)

3.2 反向传播

而BP算法的目标则是最小化网络的均方误差。其采用的优化策略为梯度下降法。根据梯度下降计算出的梯度项进行参数调整。

Δ w h j = − η ∂ E k ∂ w h j (5) \Delta w_{hj}=-\eta\frac{ \partial E_k }{ \partial w_{hj}}\tag{5} Δwhj=ηwhjEk(5)

则按照梯度下降的计算公式,对于隐含层到输出层的权值调整为:

Δ w h j = η g j b h (6) \Delta w_{hj}=\eta g_jb_h \tag{6} Δwhj=ηgjbh(6)

其中 b h b_h bh为第h个隐含层神经元的输出。 g j g_j gj为梯度项:

g j = − ∂ E k ∂ y ^ j k ∗ y ^ j k ∂ β j = y ^ j k ( 1 − y ^ j k ) ( y j k − y ^ j k ) (7) \begin{split} g_j & =-\frac{ \partial E_k }{ \partial \hat y_j^k}*\frac{ \hat y_j^k}{ \partial \beta _j} \\ &=\hat y_j^k(1-\hat y_j^k)(y_j^k-\hat y_j^k) \end{split}\tag{7} gj=y^jkEkβjy^jk=y^jk(1y^jk)(yjky^jk)(7)

同理可得:

Δ θ j = − η g j (8) \Delta \theta_j=-\eta g_j \tag{8} Δθj=ηgj(8)

Δ w i h = η e h x 1 (9) \Delta w_{ih}=\eta e_h x_1 \tag{9} Δwih=ηehx1(9)

Δ γ h = − η e h (10) \Delta \gamma_h=-\eta e_h \tag{10} Δγh=ηeh(10)

其中 e h e_h eh为隐含层权值调整梯度项,设第h个隐含层的输入为 α h \alpha_h αh,则:

e h = − ∂ E k ∂ b h ∗ b h ∂ α h = b h ( 1 − b h ) ∑ j = 1 l w h i g j (11) \begin{split} e_h & =-\frac{ \partial E_k }{ \partial b_h}*\frac{ b_h}{ \partial \alpha_h} \\ &=b_h(1-b_h)\sum_{j=1}^lw_{hi}g_j \end{split}\tag{11} eh=bhEkαhbh=bh(1bh)j=1lwhigj(11)

对任意参数v其更新公式为:

v ← v + Δ v (12) v\leftarrow v+\Delta v \tag{12} vv+Δv(12)

4.算法流程

输入:训练集 D = { ( X k , y k ) } m D=\{(X_k,y_k)\}^m D={(Xk,yk)}m,学习率 η \eta η

过程:

1.在(0,1)的范围内随机初始化网络中的所有参数。

2.repeat

  • for all( ( X k , y k ) ∈ D (X_k,y_k)\in D (Xk,yk)D) do

    • 正向传播,根据当前参数和输入数据计算当前样本的输出 y ^ j k \hat y_j^k y^jk

    • 根据公式(7)计算输出层神经元的梯度项 g j g_j gj

    • 根据公式(11)计算隐含层神经元梯度项 e h e_h eh

    • 根据公式(6)(8-10)(12)更新连接权与阈值

    end for

    util 达到停止条件

输出:训练好的神经网络

5.算法实现

1.GeneralAnn.java

抽象类,定义ANN的基础结构

/**
 * GeneralAnn.java
 *
 * @Author zjy
 * @Date 2022/7/28
 * @Description: 抽象类,定义ANN的基础结构
 * @Version V1.0
 */

package swpu.zjy.ML.ANN.myAnn;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;


public abstract class GeneralAnn {
    //数据集
    Instances dataset;
    //神经网络层数
    int numLayers;
    //每层神经网络结点数量
    int[] numLayerNodes;
    //动量,用于加速梯度下降
    public double mobp;
    //梯度下降学习率
    public double learningRate;
    //随机数生成器
    Random random = new Random();

    /**
     * 构造器,初始化网络参数
     *
     * @param datasetFileName   数据集地址
     * @param paraNumLayerNodes 神经网络每层结点数量
     * @param paraMobp          动量
     * @param paraLR            学习率
     */
    public GeneralAnn(String datasetFileName, int[] paraNumLayerNodes, double paraMobp, double paraLearnRate) {
        FileReader fileReader = null;
        try {
            fileReader = new FileReader(datasetFileName);
            dataset = new Instances(fileReader);
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception e) {
            System.out.println(e);
        }

        numLayerNodes = paraNumLayerNodes;
        numLayers = numLayerNodes.length;
        numLayerNodes[0] = dataset.numAttributes() - 1;
        numLayerNodes[numLayers-1]=dataset.numClasses();
        learningRate = paraLearnRate;
        mobp = paraMobp;
    }

    /**
     * 抽象方法,向前传播,输出预测结果
     *
     * @param paraInput 输入数据
     * @return 预测结果
     */
    public abstract double[] forward(double[] paraInput);

    /**
     * 抽象方法,向后传播,调整网络参数
     *
     * @param paraTarget 预计目标结果
     */
    public abstract void backPropagation(double[] paraTarget);

    /**
     * 使用数据集进行训练
     */
    public void train() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];

        for (int i = 0; i < dataset.numInstances(); i++) {
            /**
             * 填充输入数据
             */
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            }
            /**
             * 填充预计目标
             */
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            forward(tempInput);
            backPropagation(tempTarget);
        }
    }

    /**
     * 获取最大值下标
     *
     * @param paraArray 数组
     * @return 最大值下标
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        double tempMax = -1e10;
        for (int i = 0; i < paraArray.length; i++) {
            if (tempMax < paraArray[i]) {
                tempMax = paraArray[i];
                resultIndex = i;
            }
        }
        return resultIndex;
    }

    /**
     * 进行测试
     *
     * @return AUC
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];

        double tempNumCorrect = 0;
        double[] tempPrediction;
        int[] predict=new int[dataset.numInstances()];
        Arrays.fill(predict,0);
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            }

            tempPrediction = forward(tempInput);

            tempPredictedClass = argmax(tempPrediction);
            predict[i]=tempPredictedClass;
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            }
        }
        System.out.println("Correct:" + tempNumCorrect + "out of" + dataset.numInstances());
        return tempNumCorrect / dataset.numInstances();
    }

}

2.Activator.java

激活函数封装

package swpu.zjy.ML.ANN.myAnn;

/**
 * Activator.java
 *
 * @Author Fan Min minfanphd@163.com.
 * @Date 2022/7/29
 * @Description: 激活函数
 * @Version V1.0
 */


public class Activator {
    public final char ARC_TAN='a';
    public final char ELU='e';
    public final char GELU= 'g';
    public final char HARD_LOGISTIC='h';
    public final char IDENTITY = 'i';
    public final char LEAKY_RELU = 'l';
    public final char RELU = 'r';
    public final char SOFT_SIGN = 'o';
    public final char SIGMOID = 's';
    public final char TANH = 't';
    public final char SOFT_PLUS = 'u';
    public final char SWISH = 'w';
    private char activator;
    double alpha;
    double beta;
    double gamma;

    /**
     * 构造器。设置激活函数类型
     * @param activator 激活函数类型
     */
    public Activator(char activator) {
        this.activator = activator;
    }

    public char getActivator() {
        return activator;
    }

    public void setActivator(char activator) {
        this.activator = activator;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    public void setBeta(double beta) {
        this.beta = beta;
    }

    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /**
     * 激活函数
     * @param paraValue 欲激活数据
     * @return 激活后的数据
     */
    public double activate(double paraValue) {
        double resultValue = 0;
        switch (activator) {
            case ARC_TAN:
                resultValue = Math.atan(paraValue);
                break;
            case ELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = alpha * (Math.exp(paraValue) - 1);
                } // Of if
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = paraValue;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = alpha * paraValue;
                } // Of if
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
                    resultValue = paraValue / (1 + paraValue);
                } else {
                    resultValue = paraValue / (1 - paraValue);
                } // Of if
                break;
            case SOFT_PLUS:
                resultValue = Math.log(1 + Math.exp(paraValue));
                break;
            case RELU:
                if (paraValue >= 0) {
                    resultValue = paraValue;
                } else {
                    resultValue = 0;
                } // Of if
                break;
            case SIGMOID:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case TANH:
                resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }// Of switch

        return resultValue;
    }// Of activate

    /**
     * 激活函数求导,用于反向传播
     * @param paraValue 源数据
     * @param paraActivatedValue f(x)
     * @return
     */
    public double derive(double paraValue,double paraActivatedValue) {
        double resultValue = 0;
        switch (activator) {
            case ARC_TAN:
                resultValue = 1 / (paraValue * paraValue + 1);
                break;
            case ELU:
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
                } // Of if
                break;
            // case GELU:
            // resultValue = ?;
            // break;
            // case HARD_LOGISTIC:
            // resultValue = ?;
            // break;
            case IDENTITY:
                resultValue = 1;
                break;
            case LEAKY_RELU:
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = alpha;
                } // Of if
                break;
            case SOFT_SIGN:
                if (paraValue >= 0) {
                    resultValue = 1 / (1 + paraValue) / (1 + paraValue);
                } else {
                    resultValue = 1 / (1 - paraValue) / (1 - paraValue);
                } // Of if
                break;
            case SOFT_PLUS:
                resultValue = 1 / (1 + Math.exp(-paraValue));
                break;
            case RELU: // Updated
                if (paraValue >= 0) {
                    resultValue = 1;
                } else {
                    resultValue = 0;
                } // Of if
                break;
            case SIGMOID: // Updated
                resultValue = paraActivatedValue * (1 - paraActivatedValue);
                break;
            case TANH: // Updated
                resultValue = 1 - paraActivatedValue * paraActivatedValue;
                break;
            // case SWISH:
            // resultValue = ?;
            // break;
            default:
                System.out.println("Unsupported activator: " + activator);
                System.exit(0);
        }// Of switch

        return resultValue;
    }
    public String toString() {
        String resultString = "Activator with function '" + activator + "'";
        resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;

        return resultString;
    }

    public static void main(String[] args) {
        Activator tempActivator = new Activator('s');
        double tempValue = 0.6;
        double tempNewValue;
        tempNewValue = tempActivator.activate(tempValue);
        System.out.println("After activation: " + tempNewValue);

        tempNewValue = tempActivator.derive(tempValue, tempNewValue);
        System.out.println("After derive: " + tempNewValue);
    }
}

3.AnnLayer.java

Ann 层的实现

package swpu.zjy.ML.ANN.myAnn;

import java.util.Arrays;
import java.util.Random;

/**
 * AnnLayer.java
 *
 * @Author zjy
 * @Date 2022/7/30
 * @Description: ANN Layer定义
 * @Version V1.0
 */


public class AnnLayer {
    //输入数据数量
    int numInput;
    //输出数据数量
    int numOutput;
    //学习率
    double learningRate;
    //动量
    double mobp;
    //权值
    double[][] weights;
    //权值改变量
    double[][] deltaWeights;
    //输出误差
    double[] errors;
    //输入数据
    double[] input;
    //输出数据
    double[] output;
    //激活后的输出数据
    double[] activatedOutput;
    //激活函数
    Activator activator;
    //随机数生成器
    Random random=new Random();

    /**
     * 构造器,初始化层相关参数
     * @param paraNumInput  输出数据数量
     * @param paraNumOutput 输出数据数量
     * @param paraActivator 激活函数类型
     * @param paraLearningRate 学习率
     * @param paraMobp 动量
     */
    public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
                    double paraLearningRate, double paraMobp){
        numInput=paraNumInput;
        numOutput=paraNumOutput;
        learningRate=paraLearningRate;
        mobp=paraMobp;
        weights=new double[numInput+1][numOutput];
        deltaWeights=new double[numInput+1][numOutput];
        //初始化权值
        for (int i = 0; i < numInput+1; i++) {
            for (int j = 0; j < numOutput; j++) {
                weights[i][j]=random.nextDouble();
            }
        }

        errors=new double[numInput];
        input=new double[numInput];
        output=new double[numOutput];
        activatedOutput=new double[numOutput];
        activator=new Activator(paraActivator);
    }

    /**
     * 设置激活函数参数
     * @param paraAlpha
     * @param paraBeta
     * @param paraGamma
     */
    public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
        activator.setAlpha(paraAlpha);
        activator.setBeta(paraBeta);
        activator.setGamma(paraGamma);
    }

    /**
     * 向前传播
     * @param paraInput 输入数据
     * @return 输出结果
     */
    public double[] forword(double[] paraInput){
        for (int i = 0; i < numInput; i++) {
            input[i]=paraInput[i];
        }
        for (int i = 0; i < numOutput; i++) {
            //读入偏移量
            output[i]=weights[numInput][i];
            for (int j = 0; j < numInput; j++) {
                output[i]+=input[j]*weights[j][i];
            }
            activatedOutput[i]=activator.activate(output[i]);
        }
        return activatedOutput;
    }

    /**
     * 单层反向传播
     * @param paraErrors 后一层的传播误差
     * @return 向前一层传播的误差
     */
    public double[] backPropagation(double[] paraErrors){
        //计算该层梯度项
        for (int i = 0; i < paraErrors.length; i++) {
            paraErrors[i]=activator.derive(output[i],activatedOutput[i])*paraErrors[i];
        }
        //权值调整
        for (int i = 0; i < numInput; i++) {
            errors[i]=0;
            //更新连接权值
            for (int j = 0; j < numOutput; j++) {
                errors[i] += paraErrors[j] * weights[i][j];
                deltaWeights[i][j]=mobp*deltaWeights[i][j]+learningRate*paraErrors[j]*input[i];
                weights[i][j]+=deltaWeights[i][j];
            }
        }
        //更新阈值
        for (int j = 0; j < numOutput; j++) {
            deltaWeights[numInput][j]=mobp*deltaWeights[numInput][j]+learningRate*paraErrors[j];
            weights[numInput][j]+=deltaWeights[numInput][j];
        }
        return errors;
    }

    /**
     * 获取后一层误差
     * @param paraTarget 预计目标
     * @return 本次传播误差
     */
    public double[] getLastLayerErrors(double[] paraTarget){
        double[] resultErrors=new double[numOutput];
        for (int i = 0; i < numOutput; i++) {
            resultErrors[i]=(paraTarget[i]-activatedOutput[i]);
        }
        return resultErrors;
    }
    public String toString() {
        String resultString = "";
        resultString += "Activator: " + activator;
        resultString += "\r\n weights = " + Arrays.deepToString(weights);
        return resultString;
    }

    public static void unitTest() {
        swpu.zjy.ML.ANN.teacher.AnnLayer tempLayer = new swpu.zjy.ML.ANN.teacher.AnnLayer(2, 3, 's', 0.01, 0.1);
        double[] tempInput = { 1, 4 };

        System.out.println(tempLayer);

        double[] tempOutput = tempLayer.forward(tempInput);
        System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

        double[] tempError = tempLayer.backPropagation(tempOutput);
        System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
    }

    public static void main(String[] args) {
        unitTest();
    }
}

4.FullAnn.java

BP神经网络实现

package swpu.zjy.ML.ANN.myAnn;

/**
 * FullAnn.java
 *
 * @Author zjy
 * @Date 2022/7/30
 * @Description: BP神经网络实现
 * @Version V1.0
 */


public class FullAnn extends GeneralAnn{
    //神经网络层数
    AnnLayer[] layers;

    /**
     * 构造器,初始化网络参数,构造神经网络结构
     *
     * @param datasetFileName   数据集地址
     * @param paraNumLayerNodes 神经网络每层结点数量
     * @param paraMobp          动量
     * @param paraLearnRate            学习率
     */
    public FullAnn(String datasetFileName, int[] paraNumLayerNodes, double paraMobp, double paraLearnRate,String paraActivator) {
        super(datasetFileName, paraNumLayerNodes, paraMobp, paraLearnRate);

        //生成层
        layers=new AnnLayer[numLayers-1];
        for (int i = 0; i < layers.length; i++) {
            layers[i]=new AnnLayer(numLayerNodes[i],numLayerNodes[i+1],paraActivator.charAt(i),paraLearnRate,paraMobp);
        }
    }

    /**
     * 前向传播
     * @param paraInput 输入数据
     * @return 预测结果
     */
    @Override
    public double[] forward(double[] paraInput) {
        double[] resultArray=paraInput;
        for (int i = 0; i < numLayers-1; i++) {
            resultArray=layers[i].forword(resultArray);
        }
        return resultArray;
    }

    /**
     *
     * @param paraTarget 预计目标结果
     */
    @Override
    public void backPropagation(double[] paraTarget) {
        double[] tempErrors=layers[numLayers-2].getLastLayerErrors(paraTarget);
        for (int i = numLayers - 2; i >= 0; i--){
            tempErrors=layers[i].backPropagation(tempErrors);
        }
        return;
    }
    public String toString() {
        return "I am a full ANN with " + numLayers + " layers";
    }

    public static void main(String[] args) {
        int[] tempLayerNodes = { 4, 8, 8, 3 };
        FullAnn tempNetwork = new FullAnn("src/main/java/swpu/zjy/ML/DataSet/iris.arff", tempLayerNodes, 0.6,0.01,
                 "sss");

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        } // Of for n

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
        System.out.println("FullAnn ends.");
    }
}
String() {
        return "I am a full ANN with " + numLayers + " layers";
    }

    public static void main(String[] args) {
        int[] tempLayerNodes = { 4, 8, 8, 3 };
        FullAnn tempNetwork = new FullAnn("src/main/java/swpu/zjy/ML/DataSet/iris.arff", tempLayerNodes, 0.6,0.01,
                 "sss");

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        } // Of for n

        double tempAccuray = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuray);
        System.out.println("FullAnn ends.");
    }
}

5.运行测试

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值