神经网络模型搭建及Java实例

首先什么是人工神经网络?简单来说就是将单个感知器作为一个神经网络节点,然后用此类节点组成一个层次网络结构,我们称此网络即为人工神经网络(本人自己的理解)。当网络的层次大于等于3层(输入层+隐藏层(大于等于1)+输出层)时,我们称之为多层人工神经网络。

1、神经单元的选择

  那么我们应该使用什么样的感知器来作为神经网络节点呢?在上一篇文章我们介绍过感知器算法,但是直接使用的话会存在以下问题:

  1)感知器训练法则中的输出

  由于sign函数时非连续函数,这使得它不可微,因而不能使用上面的梯度下降算法来最小化损失函数。

  2)增量法则中的输出为;

  每个输出都是输入的线性组合,这样当多个线性单元连接在一起后最终也只能得到输入的线性组合,这和只有一个感知器单元节点没有很大不同。 

  为了解决上面存在的问题,一方面,我们不能直接使用线性组合的方式直接输出,需要在输出的时候添加一个处理函数;另一方面,添加的处理函数一定要是可微的,这样我们才能使用梯度下降算法。

  满足上面条件的函数非常的多,但是最经典的莫过于sigmoid函数,又称Logistic函数,此函数能够将内的任意数压缩到(0,1)之间,因此这个函数又称为挤压函数。为了将此函数的输入更加规范化,我们在输入的线性组合中添加一个阀值,使得输入的线性组合以0为分界点。

  sigmoid函数:

  其函数曲线如图1.1所示。

图1.1 sigmoid函数曲线[2]

  此函数有个重要特性就是他的导数:

 

  有了此特性在计算它的梯度下降时就简便了很多。

  另外还有双曲函数tanh也可以用来替代sigmoid函数,二者的曲线图比较类似。

 

2、反向传播算法又称BP算法(Back Propagation)

         现在,我们可以用上面介绍的使用sigmoid函数的感知器来搭建一个多层神经网络,为简单起见,此处我们使用三层网络来分析。假设网络拓补如图2.1所示。

图2.1 BP网络拓补结构[3]

  网络的运行流程为:当输入一个样例后,获得该样例的特征向量,再根据权向量得到感知器的输入值,然后使用sigmoid函数计算出每个感知器的输出,再将此输出作为下一层感知器的输入,依次类推,直到输出层。

  那么如何确定每个感知器的权向量呢?这时我们需要使用反向传播算法来逐步进行优化。在正式介绍反向传播算法之前,我们先继续进行分析。

  在上一篇介绍感知器的文章中,为了得到权向量,我们通过最小化损失函数来不断调整权向量。此方法也适用于此处求解权向量,首先我们需要定义损失函数,由于网络的输出层有多个输出结点,我们需要将输出层每个输出结点的差值平方求和。于是得到每一个训练样例的损失函数为:(前面加个0.5方便后面求导使用)

  在多层的神经网络中,误差曲面可能有多个局部极小值,这意味着使用梯度下降算法找到的可能是局部极小值,而不是全局最小值。

  现在我们有了损失函数,这时可以根据损失函数来调整输出结点中的输入权向量,这类似感知器中的随机梯度下降算法,然后从后向前逐层调整权重,这就是反向传播算法的思想。

 

具有两层sigmoid单元的前馈网络的反向传播算法:

1)将网络中的所有权值随机初始化。

2)对每一个训练样例,执行如下操作:

  A)根据实例的输入,从前向后依次计算,得到输出层每个单元的输出。然后从输出层开始反向计算每一层的每个单元的误差项。

  B)对于输出层的每个单元k,计算它的误差项:

  C)对于网络中每个隐藏单元h,计算它的误差项:

  D)更新每个权值:

符号说明:

xji:结点i到结点j的输入,wji表示对应的权值。

outputs:表示输出层结点集合。

整个算法与delta法则的随机梯度下降算法类似,算法分析如下:

  1)权值的更新方面,和delta法则类似,主要依靠学习速率,该权值对应的输入,以及单元的误差项。

  2)对输出层单元,它的误差项是(t-o)乘以sigmoid函数的导数ok(1-ok),这与delta法则的误差项有所不同,delta法则的误差项为(t-o)。

  3)对于隐藏层单元,因为缺少直接的目标值来计算隐藏单元的误差,因此需要以间接的方式来计算隐藏层的误差项对受隐藏单元h影响的每一个单元的误差进行加权求和,每个误差权值为wkh, wkh就是隐藏单元h到输出单元k的权值。

 

3、反向传播算法的推导

  算法的推导过程主要是利用梯度下降算法最小化损失函数的过程,现在损失函数为:

  对于网络中的每个权值wji,计算其导数:

  1)若j是网络的输出层单元

  对netj的求导:

  其中:

  

  

  所以有:

  为了使表达式简洁,我们使用:

  权值的改变朝着损失函数的负梯度方向,于是有权值改变量:

 

2)若j是网络中的隐藏单元

  由于隐藏单元中w的值通过下一层来间接影响输入,故使用逐层剥离的方式来进行求导:

  因为:

  所以:

  同样,我们使用:

  所以权值变化量:

 

4、算法的改进

  反向传播算法的应用非常的广泛,为了满足各种不同的需求,产生了很多不同的变体,下面介绍两种变体:

  1)增加冲量项

  此方法主要是修改权值更新法则。他的主要思想在于让第n次迭代时的权值的更新部分依赖于第n-1次的权值。

  其中0<=a<1:称为冲量的系数。加入冲量项在一定程度上起到加大搜索步长的效果,从而能更快的进行收敛。另一方面,由于多层网络易导致损失函数收敛到局部极小值,但通过冲量项在某种程度上可以越过某些狭窄的局部极小值,达到更小的地方。

  2)学习任意的深度的无环网络

  在上述介绍的反向传播算法实际只有三层,即只有一层隐藏层的情况,要是有很多隐藏层应当如何进行处理?

  现假设神经网络共有m+2层,即有m层的隐藏层。这时,只需要变化一个地方即可得到具有m个隐藏层的反向传播算法。第k层的单元r的误差 的值由更深的第k+1层的误差项计算得到:


      下边用Java进行神经网络的实例搭建。

BPNN类,实现神经网络的主要算法。

package bpnn;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

/**
 * BP neural network core code and prediction processing code
 * @author Pumpkin
 * @since 2018/04/02
 * @version 1.0
 */
public class BPNN {

    
    // private static int LAYER = 3; // Three-layer neural network
    private static final int NodeNum = 10; // The maximum number of nodes per layer
    private static final int ADJUST = 5; // Hidden layer node adjustment constant
    private static final int MaxTrain = 2000; // Maximum training times
    private static final double ACCU = 0.015; // Allowable error for each iteration
    private double ETA_W = 0.5; // Weight learning efficiency
    private double ETA_T = 0.5; // Threshold learning efficiency
    private double accu;

    // 附加动量项
    //private static final double ETA_A = 0.3; // 动量常数0.1
    //private double[][] in_hd_last; // 上一次的权值调整量
    //private double[][] hd_out_last;

    private int in_num; // The number of input layer nodes
    private int hd_num; // The number of hidden layer nodes
    private int out_num; // The number of output layer nodes

    private ArrayList<ArrayList<Double>> list = new ArrayList<>(); // Input and output data

    private double[][] in_hd_weight; // BP network in-hidden synaptic weight
    private double[][] hd_out_weight; // BP network hidden_out synaptic weight
    private double[] in_hd_th; // BP network in-hidden threshold
    private double[] hd_out_th; // BP network hidden-out threshold

    private double[][] out; // The output value of each neuron converted by the sigmoid function, the input layer is the original value
    private double[][] delta; // Delta learning rules

    /** Obtaining the largest number of neurons in the third layer of the network
     * @return  **/
    public int GetMaxNum() {
        return Math.max(Math.max(in_num, hd_num), out_num);
    }

    // Setting the weight learning rate
    public void SetEtaW() {
        ETA_W = 0.5;
    }

    // Set threshold learning rate
    public void SetEtaT() {
        ETA_T = 0.5;
    }

    // BP neural network training
    public void Train(int in_number, int out_number,
            ArrayList<ArrayList<Double>> arraylist) throws IOException {
        list = arraylist;
        in_num = in_number;
        out_num = out_number;

        GetNums(in_num, out_num); // Get the number of nodes in the input layer, hidden layer, and output layer
        InitNetWork(); // Initialize network weights and thresholds

        int datanum = list.size(); // Number of training data 
        int createsize = GetMaxNum(); // Compare to create an array that stores the output data for each layer
        out = new double[3][createsize];

        for (int iter = 0; iter < MaxTrain; iter++) {
            for (int cnd = 0; cnd < datanum; cnd++) {
                // 第一层输入节点赋值

                for (int i = 0; i < in_num; i++) {
                    out[0][i] = list.get(cnd).get(i); // 为输入层节点赋值,其输入与输出相同
                }
                Forward(); // 前向传播
                Backward(cnd); // 误差反向传播

            }
            System.out.println("This is the " + (iter + 1)
                    + " th trainning NetWork !");
            accu = GetAccu();
            System.out.println("All Samples Accuracy is " + accu);
            if (accu < ACCU)
                break;

        }

    }

    // 获取输入层、隐层、输出层的节点数,in_number、out_number分别为输入层节点数和输出层节点数
    public void GetNums(int in_number, int out_number) {
        in_num = in_number;
        out_num = out_number;
        hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;
        if (hd_num > NodeNum)
            hd_num = NodeNum; // 隐层节点数不能大于最大节点数
    }

    // 初始化网络的权值和阈值
    public void InitNetWork() {
        // 初始化上一次权值量,范围为-0.5-0.5之间
        //in_hd_last = new double[in_num][hd_num];
        //hd_out_last = new double[hd_num][out_num];

        in_hd_weight = new double[in_num][hd_num];
        for (int i = 0; i < in_num; i++)
            for (int j = 0; j < hd_num; j++) {
                int flag = 1; // 符号标志位(-1或者1)
                if ((new Random().nextInt(2)) == 1)
                    flag = 1;
                else
                    flag = -1;
                in_hd_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化in-hidden的权值
                //in_hd_last[i][j] = 0;
            }

        hd_out_weight = new double[hd_num][out_num];
        for (int i = 0; i < hd_num; i++)
            for (int j = 0; j < out_num; j++) {
                int flag = 1; // 符号标志位(-1或者1)
                if ((new Random().nextInt(2)) == 1)
                    flag = 1;
                else
                    flag = -1;
                hd_out_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化hidden-out的权值
                //hd_out_last[i][j] = 0;
            }

        // 阈值均初始化为0
        in_hd_th = new double[hd_num];
        for (int k = 0; k < hd_num; k++)
            in_hd_th[k] = 0;

        hd_out_th = new double[out_num];
        for (int k = 0; k < out_num; k++)
            hd_out_th[k] = 0;

    }

    // 计算单个样本的误差
    public double GetError(int cnd) {
        double ans = 0;
        for (int i = 0; i < out_num; i++)
            ans += 0.5 * (out[2][i] - list.get(cnd).get(in_num + i))
                    * (out[2][i] - list.get(cnd).get(in_num + i));
        return ans;
    }

    // 计算所有样本的平均精度
    public double GetAccu() {
        double ans = 0;
        int num = list.size();
        for (int i = 0; i < num; i++) {
            int m = in_num;
            for (int j = 0; j < m; j++)
                out[0][j] = list.get(i).get(j);
            Forward();
            int n = out_num;
            for (int k = 0; k < n; k++)
                ans += 0.5 * (list.get(i).get(in_num + k) - out[2][k])
                        * (list.get(i).get(in_num + k) - out[2][k]);
        }
        return ans / num;
    }

    // 前向传播
    public void Forward() {
        // 计算隐层节点的输出值
        for (int j = 0; j < hd_num; j++) {
            double v = 0;
            for (int i = 0; i < in_num; i++)
                v += in_hd_weight[i][j] * out[0][i];
            v += in_hd_th[j];
            out[1][j] = Sigmoid(v);
        }
        // 计算输出层节点的输出值
        for (int j = 0; j < out_num; j++) {
            double v = 0;
            for (int i = 0; i < hd_num; i++)
                v += hd_out_weight[i][j] * out[1][i];
            v += hd_out_th[j];
            out[2][j] = Sigmoid(v);
        }
    }

    // 误差反向传播
    public void Backward(int cnd) {
        CalcDelta(cnd); // 计算权值调整量
        UpdateNetWork(); // 更新BP神经网络的权值和阈值
    }

    // 计算delta调整量
    public void CalcDelta(int cnd) {

        int createsize = GetMaxNum(); // 比较创建数组
        delta = new double[3][createsize];
        // 计算输出层的delta值
        for (int i = 0; i < out_num; i++) {
            delta[2][i] = (list.get(cnd).get(in_num + i) - out[2][i])
                    * SigmoidDerivative(out[2][i]);
        }

        // 计算隐层的delta值
        for (int i = 0; i < hd_num; i++) {
            double t = 0;
            for (int j = 0; j < out_num; j++)
                t += hd_out_weight[i][j] * delta[2][j];
            delta[1][i] = t * SigmoidDerivative(out[1][i]);
        }
    }

    // 更新BP神经网络的权值和阈值
    public void UpdateNetWork() {

        // 隐含层和输出层之间权值和阀值调整
        for (int i = 0; i < hd_num; i++) {
            for (int j = 0; j < out_num; j++) {
                hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i]; // 未加权值动量项
                /* 动量项
                 * hd_out_weight[i][j] += (ETA_A * hd_out_last[i][j] + ETA_W
                 * delta[2][j] * out[1][i]); hd_out_last[i][j] = ETA_A *
                 * hd_out_last[i][j] + ETA_W delta[2][j] * out[1][i];
                 */
            }

        }
        for (int i = 0; i < out_num; i++)
            hd_out_th[i] += ETA_T * delta[2][i];

        // 输入层和隐含层之间权值和阀值调整
        for (int i = 0; i < in_num; i++) {
            for (int j = 0; j < hd_num; j++) {
                in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i]; // 未加权值动量项
                /* 动量项
                 * in_hd_weight[i][j] += (ETA_A * in_hd_last[i][j] + ETA_W
                 * delta[1][j] * out[0][i]); in_hd_last[i][j] = ETA_A *
                 * in_hd_last[i][j] + ETA_W delta[1][j] * out[0][i];
                 */
            }
        }
        for (int i = 0; i < hd_num; i++)
            in_hd_th[i] += ETA_T * delta[1][i];
    }

    // 符号函数sign
    public int Sign(double x) {
        if (x > 0)
            return 1;
        else if (x < 0)
            return -1;
        else
            return 0;
    }

    // 返回最大值
    public double Maximum(double x, double y) {
        if (x >= y)
            return x;
        else
            return y;
    }

    // 返回最小值
    public double Minimum(double x, double y) {
        if (x <= y)
            return x;
        else
            return y;
    }

    // log-sigmoid函数
    public double Sigmoid(double x) {
        return (double) (1 / (1 + Math.exp(-x)));
    }

    // log-sigmoid函数的倒数
    public double SigmoidDerivative(double y) {
        return (double) (y * (1 - y));
    }

    // tan-sigmoid函数
    public double TSigmoid(double x) {
        return (double) ((1 - Math.exp(-x)) / (1 + Math.exp(-x)));
    }

    // tan-sigmoid函数的倒数
    public double TSigmoidDerivative(double y) {
        return (double) (1 - (y * y));
    }

    // 分类预测函数
    public ArrayList<ArrayList<Double>> ForeCast(
            ArrayList<ArrayList<Double>> arraylist) {

        ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>();
        ArrayList<Double> outlist = new ArrayList<>();
        int datanum = arraylist.size();
        for (int cnd = 0; cnd < datanum; cnd++) {
            for (int i = 0; i < in_num; i++)
                out[0][i] = arraylist.get(cnd).get(i); // 为输入节点赋值
            Forward();
            for (int i = 0; i < out_num; i++) {
                if (out[2][i] > 0 && out[2][i] < 0.5)
                    out[2][i] = 0;
                else if (out[2][i] > 0.5 && out[2][i] < 1) {
                    out[2][i] = 1;
                }
                outlist.add(out[2][i]);
            }
            alloutlist.add(outlist);
            outlist = new ArrayList<>();
            outlist.clear();
        }
        return alloutlist;
    }
    
}

DataUtil类主要用于进行数据的预处理以及训练和测试数据。

package bpnn;

/**
 * Data processing class to process training data and test data
 * @author Pumpkin
 * @
 */
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

class DataUtil {

    private final ArrayList<ArrayList<Double>> alllist = new ArrayList<>(); // 存放所有数据
    private final ArrayList<String> outlist = new ArrayList<>(); // 存放输出数据,索引对应每个everylist的输出
    private final ArrayList<String> checklist = new ArrayList<>();  //存放测试集的真实输出字符串
    private int in_num = 0;
    private int out_num = 0; // 输入输出数据的个数
    private int type_num = 0; // 输出的类型数量
    private double[][] nom_data; //归一化输入数据中的最大值和最小值/*
    private int in_data_num = 0; //提前获得输入数据的个数*/

    // 获取输出类型的个数
    public int GetTypeNum() {
        return type_num;
    }

    // 设置输出类型的个数
    public void SetTypeNum(int type_num) {
        this.type_num = type_num;
    }

    // 获取输入数据的个数
    public int GetInNum() {
        return in_num;
    }

    // 获取输出数据的个数
    public int GetOutNum() {
        return out_num;
    }

    // 获取所有数据的数组
    public ArrayList<ArrayList<Double>> GetList() {
        return alllist;
    }

    // 获取输出为字符串形式的数据
    public ArrayList<String> GetOutList() {
        return outlist;
    }

    // 获取输出为字符串形式的数据
    public ArrayList<String> GetCheckList() {
        return checklist;
    }

    //返回归一化数据所需最大最小值
    public double[][] GetMaxMin(){

        return nom_data;
    }

    // 读取文件初始化数据
    public void ReadFile(String filepath, String sep, int flag)
            throws Exception {

        ArrayList<Double> everylist = new ArrayList<>(); // 存放每一组输入输出数据
        int readflag = flag; // flag=0,train;flag=1,test
        String encoding = "utf-8";
        File file = new File(filepath);
        if (file.isFile() && file.exists()) { // 判断文件是否存在
            InputStreamReader read = new InputStreamReader(new FileInputStream(
                    file), encoding);// 考虑到编码格式
            try (BufferedReader bufferedReader = new BufferedReader(read)) {
                String lineTxt = null;
                while ((lineTxt = bufferedReader.readLine()) != null) {
                    int in_number = 0;
                    String splits[] = lineTxt.split(sep); // 按','截取字符串
                    if (readflag == 0) {
                        for (int i = 0; i < splits.length; i++)
                            try {
                                everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
                                in_number++;
                            } catch (NumberFormatException e) {
                                if (!outlist.contains(splits[i]))
                                    outlist.add(splits[i]); // 存放字符串形式的输出数据
                                for (int k = 0; k < type_num; k++) {
                                    everylist.add(0.0);
                                }
                            //everylist.set(in_number + outlist.indexOf(splits[i]),1.0);
                            }
                    } else if (readflag == 1) {
                        for (int i = 0; i < splits.length; i++)
                            try {
                                everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
                                in_number++;
                            } catch (NumberFormatException e) {
                                checklist.add(splits[i]); // 存放字符串形式的输出数据
                            }
                    }
                    alllist.add(everylist); // 存放所有数据
                    in_num = in_number;
                    out_num = type_num;
                    everylist = new ArrayList<>();
                    everylist.clear();
                    
                }
            }
        }
    }

    //向文件写入分类结果
    public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number,  ArrayList<String> resultlist) throws IOException{
        File file = new File(filepath);
        FileWriter fw = null;
        BufferedWriter writer = null;
        try {
            fw = new FileWriter(file);
            writer = new BufferedWriter(fw);
            for(int i=0;i<list.size();i++){
                for(int j=0;j<in_number;j++)
                    writer.write(list.get(i).get(j)+",");
                writer.write(resultlist.get(i));
                writer.newLine();
            }
            writer.flush();
        } catch (IOException e) {
        }finally{
            writer.close();
            fw.close();
        }
    }


    //学习样本归一化,找到输入样本数据的最大值和最小值
    public void NormalizeData(String filepath) throws IOException{
        //提前获得输入数据的个数   
        GetBeforIn(filepath);
        int flag=1;
        nom_data = new double[in_data_num][2];
        String encoding = "utf-8";
        File file = new File(filepath);
        if (file.isFile() && file.exists()) { // 判断文件是否存在
            InputStreamReader read = new InputStreamReader(new FileInputStream(
                    file), encoding);// 考虑到编码格式
            try (BufferedReader bufferedReader = new BufferedReader(read)) {
                String lineTxt = null;
                while ((lineTxt = bufferedReader.readLine()) != null) {
                    String splits[] = lineTxt.split(","); // 按','截取字符串
                    for (int i = 0; i < splits.length-1; i++){
                        if(flag==1){
                            nom_data[i][0]=Double.valueOf(splits[i]);
                            nom_data[i][1]=Double.valueOf(splits[i]);
                        }
                        else{
                            if(Double.valueOf(splits[i])>nom_data[i][0])
                                nom_data[i][0]=Double.valueOf(splits[i]);
                            if(Double.valueOf(splits[i])<nom_data[i][1])
                                nom_data[i][1]=Double.valueOf(splits[i]);
                        }
                    }
                    flag=0;
                }
            }
        }
    }

    //归一化前获得输入数据的个数
    public void GetBeforIn(String filepath) throws IOException{
        String encoding = "utf-8";
        File file = new File(filepath);
        if (file.isFile() && file.exists()) { // 判断文件是否存在
            InputStreamReader read = new InputStreamReader(new FileInputStream(
                    file), encoding);// 考虑到编码格式
            try (BufferedReader beforeReader = new BufferedReader(read)) {
                String beforetext = beforeReader.readLine();
                String splits[] = beforetext.split(",");
                in_data_num = splits.length-1;
            }
        }
    }

    //归一化公式
    public double Normalize(double x, double max, double min){
        double y = 0.1+0.8*(x-min)/(max-min);
        return y;
    }
}

Test类,进行数据的读取和训练结果并显示。

package bpnn;

import java.util.ArrayList;

/**
 *
 * @author Pumpkin
 */


public class Test {
    public static void main(String args[]) throws Exception {

        ArrayList<ArrayList<Double>> alllist = new ArrayList<>(); // 存放所有数据
        ArrayList<String> outlist = new ArrayList<>(); // 存放分类的字符串
        int in_num = 0;
        int out_num = 0; // 输入输出数据的个数

        DataUtil dataUtil = new DataUtil(); // 初始化数据

        dataUtil.NormalizeData("C:\\Users\\童\\Desktop\\BPNN\\data\\train.txt");

        dataUtil.SetTypeNum(3); // 设置输出类型的数量
        dataUtil.ReadFile("C:\\Users\\童\\Desktop\\BPNN\\data\\train.txt", ",", 0);

        in_num = dataUtil.GetInNum(); // 获得输入数据的个数
        out_num = dataUtil.GetOutNum(); // 获得输出数据的个数(个数代表类型个数)
        alllist = dataUtil.GetList(); // 获得初始化后的数据

        outlist = dataUtil.GetOutList();
        System.out.print("Classification type:");
        for(int i =0 ;i<outlist.size();i++)
            System.out.print(outlist.get(i)+"  ");
        System.out.println();
        System.out.println("The number of training sets:"+alllist.size());

        BPNN bpnn = new BPNN();
        // 训练
        System.out.println("/***Train Start!***/");
        System.out.println(".............");
        bpnn.Train(in_num, out_num, alllist);
        System.out.println("/***Train End!***/");

        // 测试
        DataUtil testUtil = new DataUtil();

        testUtil.NormalizeData("C:\\Users\\童\\Desktop\\BPNN\\data\\test.txt");

        testUtil.SetTypeNum(3); // 设置输出类型的数量
        testUtil.ReadFile("C:\\Users\\童\\Desktop\\BPNN\\data\\test.txt", ",", 1);

        ArrayList<ArrayList<Double>> testList = new ArrayList<>();
        ArrayList<ArrayList<Double>> resultList = new ArrayList<>();
        ArrayList<String> normallist = new ArrayList<>(); // 存放测试集标准的输出字符串
        ArrayList<String> resultlist = new ArrayList<>(); // 存放测试集计算后的输出字符串

        double right = 0; // 分类正确的数量
        int type_num = 0; // 类型的数量
        double all_num = 0; //测试集的数量
        type_num = outlist.size();

        testList = testUtil.GetList(); // 获取测试数据
        normallist = testUtil.GetCheckList(); 

        int errorcount=0; // 分类错误的数量
        resultList = bpnn.ForeCast(testList); // 测试
        all_num=resultList.size();
        for (int i = 0; i < all_num; i++) {
            String checkString = "unknow";
            for (int j = 0; j < type_num-1; j++) {
                if(resultList.get(i).get(j)==1.0){
                    checkString = outlist.get(j);
                    resultlist.add(checkString);
                }
                /*else{
                    resultlist.add(checkString);
                }*/
            }
            /*
            if(checkString.equals("unknow"))
                errorcount++;
            */
            if(checkString.equals(normallist.get(i)))
                right++;
        }
        testUtil.WriteFile("C:\\Users\\童\\Desktop\\BPNN\\data\\result.txt",testList,in_num,resultlist);

        System.out.println("The number of test sets:"+ (new Double(all_num)).intValue());
        System.out.println("Classification correct quantity:"+(new Double(right)).intValue());
        System.out.println("The correct classification rate of the algorithm:"+right/all_num);

        System.out.println("Classification results are stored in:C:\\Users\\童\\Desktop\\BPNN\\data\\result.txt");      
    }
}

至此,简单的神经网络搭建已基本完成,开始你的表演吧。




评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值