bp神经网络的java实现(2)

紧接上一篇文章,BPNN这个类,我重写了test方法,修改后代码如下:


package myBpnn;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;


/** 
 * <Description> <br> 
 *  
 * @author wang.yueyang<br>
 * @version 1.0<br>
 * @taskId <br>
 * @CreateDate Apr 18, 2013 <br>
 * @since V7.3<br>
 * @see myBpnn <br>
 */

public class BPNN {

    /** Neuron count per layer, e.g. {2, 10, 1} = 2 inputs, 10 hidden, 1 output. */
    private int[] layers;
    /** Single shared bias neuron connected into every non-input neuron. */
    private Neuron bias = new Neuron();
    private double learningRate = 10f;
    private double momentum = 0.7f;
    /** layerList.get(0) is the input layer; the last element is the output layer. */
    private List<List<Neuron>> layerList = new ArrayList<List<Neuron>>();
    private Random random = new Random();

    public BPNN(int[] layers){
        this.layers = layers;
        init();
    }

    public BPNN(int[] layers, double learningRate, double momentum){
        this.layers = layers;
        this.learningRate = learningRate;
        this.momentum = momentum;
        init();
    }

    /**
     * Returns a readable dump of the input-layer outputs followed by the
     * output-layer outputs and their deltas, one output neuron per line.
     * (Fix: the original printed to System.out and returned {@code null},
     * which violates the {@link Object#toString()} contract.)
     */
    @Override
    public String toString(){
        StringBuilder sb = new StringBuilder();
        List<Neuron> inputLayer = layerList.get(0);
        List<Neuron> outputLayer = layerList.get(layerList.size() - 1);
        for(Neuron n : inputLayer){
            sb.append(n.getOutput()).append(' ');
        }
        for(Neuron n : outputLayer){
            sb.append(n.getOutput()).append(' ').append(n.deltaOutput).append('\n');
        }
        return sb.toString();
    }

    /** Builds the fully-connected layer structure and randomizes all weights in [-1, 1). */
    private void init(){
        int i;
        /* Build the network structure. Input layer: plain neurons, no incoming connections. */
        List<Neuron> list0 = new ArrayList<Neuron>();
        for(i = 0; i < layers[0]; i++){
            Neuron neuron = new Neuron();
            list0.add(neuron);
        }
        layerList.add(list0);
        /* Hidden and output layers: each neuron connects to every neuron of the
         * previous layer, plus the shared bias neuron. */
        for(int j = 1; j < layers.length; j++){
            List<Neuron> list = new ArrayList<Neuron>();
            for(i = 0; i < layers[j]; i++){
                Neuron neuron = new Neuron();
                neuron.addInConnectionsS(layerList.get(j - 1));
                neuron.addBiasConnection(bias);
                list.add(neuron);
            }
            layerList.add(list);
        }

        /* Initialize all weights uniformly in [-1, 1). */
        for(List<Neuron> list : layerList){
            for(Neuron neuron : list){
                for(NeuronConnection conn : neuron.getAllInConnections()){
                    conn.setWeight(2 * random.nextDouble() - 1);
                }
            }
        }
    }

    /**
     * Feeds one sample into the input layer.
     *
     * @param inputs one value per input neuron
     * @throws IllegalArgumentException if the length does not match the input layer
     *         (fix: the original only printed to stderr and then indexed out of bounds)
     */
    public void setInput(double[] inputs){
        List<Neuron> list = layerList.get(0);
        if(list.size() != inputs.length){
            throw new IllegalArgumentException(
                "expected " + list.size() + " inputs, got " + inputs.length);
        }

        for(int i = 0; i < inputs.length; i++){
            list.get(i).setOutput(inputs[i]);
        }
    }

    /** Forward pass: computes every layer's outputs except the input layer's. */
    public void forward(){
        int size = layerList.size();
        for(int i = 1; i < size; i++){
            List<Neuron> list = layerList.get(i);
            for(int j = 0; j < list.size(); j++){
                Neuron n = list.get(j);
                n.calculateOutput();
            }
        }
    }

    /** Returns the current outputs of the output layer. */
    public double[] getOutput(){
        double[] ret = new double[layers[layers.length - 1]];
        List<Neuron> list = layerList.get(layerList.size() - 1);
        for(int i = 0; i < ret.length; i++){
            ret[i] = list.get(i).getOutput();
        }
        return ret;
    }

    /**
     * Backpropagation step: adjusts all weights toward the given target,
     * using the learning rate and a momentum term on the previous delta.
     *
     * @param target expected output, one value per output neuron
     * @throws IllegalArgumentException if the length does not match the output layer
     */
    public void updateWeights(double[] target){
        if(target.length != layers[layers.length - 1]){
            throw new IllegalArgumentException(
                "expected " + layers[layers.length - 1] + " target values, got " + target.length);
        }
        int size = layerList.size();
        List<Neuron> outputLayer = layerList.get(size - 1);
        int i;
        /* Output layer: delta = out * (1 - out) * (target - out), i.e. the derivative
         * of the sigmoid times the output error. */
        for(i = 0; i < outputLayer.size(); i++){
            Neuron n = outputLayer.get(i);
            double output = n.getOutput();
            double delta_output = output * (1 - output) * (target[i] - output);
            for(NeuronConnection conn : n.getAllInConnections()){
                double leftNoutput = conn.leftNeuron.getOutput();
                double delta_weight = this.learningRate * delta_output * leftNoutput;
                conn.setDeltaWeight(delta_weight);
                // momentum * prevDeltaWeight adds the momentum contribution
                conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());
            }
            n.deltaOutput = delta_output;
        }

        /* Hidden layers, walking backwards from the last hidden layer to the first.
         * Fix: the original loop condition was `i == 1`, which skipped ALL hidden-layer
         * updates whenever the network had more than one hidden layer. */
        for(i = size - 2; i >= 1; i--){
            List<Neuron> hiddenLayer = layerList.get(i);
            List<Neuron> rightLayer = layerList.get(i + 1);
            int j;
            for (j = 0; j < hiddenLayer.size(); j++) {
                Neuron n = hiddenLayer.get(j);
                double output = n.getOutput();
                // Error is the delta of each downstream neuron weighted by the
                // connection from this neuron.
                double error = 0;
                for(int h = 0; h < rightLayer.size(); h++){
                    Neuron rightN = rightLayer.get(h);
                    NeuronConnection conn = rightN.getNeuronConn(n.id);
                    error += rightN.deltaOutput * conn.getWeight();
                }
                n.deltaOutput = output * (1 - output) * error;
                for(int m = 0; m < n.getAllInConnections().size(); m++){
                    NeuronConnection conn = n.getAllInConnections().get(m);
                    double leftNOutput = conn.leftNeuron.getOutput();
                    double delta_weight = this.learningRate * n.deltaOutput * leftNOutput;
                    conn.setDeltaWeight(delta_weight);
                    conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());
                }
            }
        }
    }

    /** One training step: forward pass on {@code inputs}, then backprop toward {@code output}. */
    public void train(double[] inputs, double[] output){
        setInput(inputs);
        forward();
        updateWeights(output);
    }

    /** Inference: forward pass on {@code inputs} and return the network's output. */
    public double[] result(double[] inputs){
        setInput(inputs);
        forward();
        return getOutput();
    }

    /** Returns a flat list of every connection (the weight "matrix") in the network. */
    public List<NeuronConnection> getWeights(){
        List<NeuronConnection> ret = new ArrayList<NeuronConnection>();
        for(List<Neuron> list : layerList){
            for(Neuron neuron : list){
                for(NeuronConnection conn : neuron.getAllInConnections()){
                    ret.add(conn);
                }
            }
        }
        return ret;
    }

    /**
     * Trains a fresh 2-10-1 network on the XOR truth table for 1000 epochs
     * and prints its outputs for all four inputs.
     * (Kept the commented-out variant below because the accompanying article
     * discusses the difference between the two versions.)
     */
    public void test(){
        /*int train = 1000;
        while(train-- > 0){
            train( new double[]{0,0}, new double[]{0});
            train( new double[]{1,0}, new double[]{1});
            train( new double[]{0,1}, new double[]{1});
            train( new double[]{1,1}, new double[]{0});
        }
        for(double d : result( new double[]{0,0})){
            System.out.println(d);
        }
        for(double d : result( new double[]{1,0})){
            System.out.println(d);
        }
        for(double d : result( new double[]{0,1})){
            System.out.println(d);
        }
        for(double d : result( new double[]{1,1})){
            System.out.println(d);
        }*/
        double[][] train_set = {{0, 0},{0, 1},{1, 0},{1, 1}};
        double[] target = new double[1];
        // Fix: the original instantiated a class `B` that does not exist in this file.
        BPNN b = new BPNN(new int[]{2,10,1});
        for(int j = 0; j < 1000; j++){
            int i = 0;
            for(double[] set : train_set){
                // XOR: target is 0 for {0,0} and {1,1}, 1 for {0,1} and {1,0}.
                if(i%4 == 0 || i%4 == 3){
                    target[0] = 0;
                }else{
                    target[0] = 1;
                }
                i++;
                b.train(set, target);
            }
        }
        for(double[] set : train_set){
            b.setInput(set);
            b.forward();
            for(double g : b.getOutput()){
                System.out.println(g);
            }
        }
    }

    public static void main(String[] args){
        int i = 100;
        while(i-- > 0){
            BPNN b = new BPNN(new int[]{2,10,1});
            b.test();
            System.out.println("************************************************");
        }
    }
}


在test()方法里面,如果用注释的代码替代现有代码,就会导致在这100次的测试中,中间会有计算不稳定的时候,比如计算1xor1会接近0;而用现有代码做测试,则100次全正常。实在搞不明白,come on!(注:两段代码在逻辑上的实质区别只有每轮训练样本的顺序——注释版按 00,10,01,11 训练,现有版按 00,01,10,11——以及训练的是不同的网络实例;结果的差异很可能主要来自每次随机初始化权重的不同,而非样本顺序本身。)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值