bp神经网络的java实现


这两天开始研究BPNN。先阅读的这篇文章:http://www.codeproject.com/Articles/16508/AI-Neural-Network-for-beginners-Part-2-of-3;然后我把这篇文章里的代码按照面向对象的方式重写了一遍。在测试过程中发现一个比较奇怪的问题,两种实现的过程数据是一样的,但最后的计算结果却不一样,这个困扰我好几天了,我把代码贴出来,大家一起研究研究。
第一个类,神经元
public class Neuron {
    static int counter = 0;
    final public int id;  // auto increment, starts at 0
    NeuronConnection biasConnection;
    final double bias = -1;
    double output;
    double deltaOutput = 0;     
    List<NeuronConnection> Inconnections = new ArrayList<NeuronConnection>();
    Map<Integer, NeuronConnection> connMap = new HashMap<Integer, NeuronConnection>();
     
    public Neuron(){        
        id = counter;
        counter++;
    }
    /**
     * Compute Sj = Wij*Aij + w0j*bias
     */
    public void calculateOutput(){
        double s = 0;
        for(NeuronConnection con : Inconnections){
            Neuron leftNeuron = con.getFromNeuron();
            double weight = con.getWeight();
            double a = leftNeuron.getOutput(); //output from previous layer
             
            s = s + (weight*a);
        }
        s = s + (biasConnection.getWeight()*bias);     
        output = g(s);
    }
      
    double g(double x) {
        return sigmoid(x);
    }
 
    double sigmoid(double x) {
        return 1.0 / (1.0 +  (Math.exp(-x)));
    }
     
    public void addInConnectionsS(List<Neuron> inNeurons){
        for(Neuron n: inNeurons){
            NeuronConnection con = new NeuronConnection(n,this);
            Inconnections.add(con);
            connMap.put(n.id, con);
        }
    }
     
    public NeuronConnection getNeuronConn(int id){
        return connMap.get(id);
    }
 
    public void addInConnection(NeuronConnection con){
        Inconnections.add(con);
    }
    public void addBiasConnection(Neuron n){
        NeuronConnection con = new NeuronConnection(n,this);
        biasConnection = con;
        Inconnections.add(con);
        connMap.put(n.id, con);
    }
    public List<NeuronConnection> getAllInConnections(){
        return Inconnections;
    }
     
    public double getBias() {
        return bias;
    }
    public double getOutput() {
        return output;
    }
    public void setOutput(double o){
        output = o;
    }
}


第二个类,神经元之间的连接类
public class NeuronConnection {
    double weight = 0;
    double deltaWeight = 0;
    double prevDeltaWeight = 0; // for momentum
    
    
    final Neuron leftNeuron;
    final Neuron rightNeuron;
    static int counter = 0;
    final public int id; // auto increment, starts at 0
 
    public NeuronConnection(Neuron fromN, Neuron toN) {
        leftNeuron = fromN;
        rightNeuron = toN;
        id = counter;
        counter++;
    }
    
    public double getWeight() {
        return weight;
    }
 
    public void setWeight(double w) {
        weight = w;
    }
 
    public void setDeltaWeight(double w) {
        prevDeltaWeight = deltaWeight;
        deltaWeight = w;
    }
 
    public double getPrevDeltaWeight() {
        return prevDeltaWeight;
    }
 
    public Neuron getFromNeuron() {
        return leftNeuron;
    }
 
    public Neuron getToNeuron() {
        return rightNeuron;
    }
}


第三个类,神经网络
public class BPNN {

    private int[] layers;
    private Neuron bias = new Neuron();
    private double learningRate = 10f;
    private double momentum = 0.7f;
    private List<List<Neuron>> layerList = new ArrayList<List<Neuron>>();
    private Random random = new Random();
    
    public BPNN(int[] layers){
        this.layers = layers;
        init();
    }
    
    public BPNN(int[] layers, double learningRate, double momentum){
        this.layers = layers;
        this.learningRate = learningRate;
        this.momentum = momentum;
        init();
    }
    
    public String toString(){
        int i;
        List<Neuron> list0 = layerList.get(0);
        List<Neuron> list1 = layerList.get(layerList.size() - 1);
        for(i = 0; i < list0.size(); i++){
            System.out.print(list0.get(i).getOutput() + " ");
        }
        for(i = 0; i < list1.size(); i++){
            System.out.println(list1.get(i).getOutput() + " " + list1.get(i).deltaOutput);
        }
        return null;
    }
    
    private void init(){
        int i;
        /*初始化神经网络结构*/
        List<Neuron> list0 = new ArrayList<Neuron>();
        for(i = 0; i < layers[0]; i++){
            Neuron neuron = new Neuron();
            list0.add(neuron);
        }
        layerList.add(list0);
        for(int j = 1; j < layers.length; j++){
            List<Neuron> list = new ArrayList<Neuron>();
            for(i = 0; i < layers[j]; i++){
                Neuron neuron = new Neuron();
                neuron.addInConnectionsS(layerList.get(j - 1));
                neuron.addBiasConnection(bias);
                list.add(neuron);
            }
            layerList.add(list);
        }
        
        /*初始化权重*/
        for(List<Neuron> list : layerList){
            for(Neuron neuron : list){
                for(NeuronConnection conn : neuron.getAllInConnections()){
                    conn.setWeight(2 * random.nextDouble() - 1);
                }
            }
        }
        System.out.println();
    }
    
    /*接收输入参数*/
    public void setInput(double[] inputs){
        List<Neuron> list = layerList.get(0);
        if(list.size() != inputs.length){
            System.err.println("入参数量不对");
        }
        
        for(int i = 0; i < inputs.length; i++){
            list.get(i).setOutput(inputs[i]);
        }
    }
    
    /*前向计算输出,不计算输入层的输出*/
    public void forward(){
        int size = layerList.size();
        for(int i = 1; i < size; i++){
            List<Neuron> list = layerList.get(i);
            for(int j = 0; j < list.size(); j++){
                Neuron n = list.get(j);
                n.calculateOutput();
            }
        }
    }
    
    /*获取计算结果*/
    public double[] getOutput(){
        double[] ret = new double[layers[layers.length - 1]];
        List<Neuron> list = layerList.get(layerList.size() - 1);
        for(int i = 0; i < ret.length; i++){
            ret[i] = list.get(i).getOutput();
        }
        return ret;
    }
    
    /*根据误差反馈修正权重*/
    public void updateWeights(double[] target){
        double e = 1;
        if(target.length != layers[layers.length - 1]){
            System.err.println("期望输出数量错误");
        }
        int size = layerList.size();
        List<Neuron> outputLayer = layerList.get(size - 1);
        int i;
        double delta_output = 0;
        for(i = 0; i < outputLayer.size(); i++){
            Neuron n = outputLayer.get(i);
            double output = n.getOutput();
            e += target[i] - output;
            delta_output = output * (1 - output) * (target[i] - output);
            for(NeuronConnection conn : n.getAllInConnections()){
                double leftNoutput = conn.leftNeuron.getOutput();
                double delta_weight = this.learningRate * delta_output * leftNoutput;
                conn.setDeltaWeight(delta_weight);
                //this.momentum * conn.getPrevDeltaWeight()添加动量影响
                conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());
            }
            n.deltaOutput = delta_output;
        }
        
        e = Math.abs(e - 1)/outputLayer.size();
        //TODO反馈修正
        for(i = size - 2; i == 1; i--){
            List<Neuron> hiddenLayer = layerList.get(i);
            List<Neuron> rightLayer = layerList.get(i + 1);
            int j;
            for (j = 0; j < hiddenLayer.size(); j++) {
                Neuron n = hiddenLayer.get(j);
                double output = n.getOutput();
                double error = 0;
                for(int h = 0; h < rightLayer.size(); h++){
                    Neuron rightN = rightLayer.get(h);
                    NeuronConnection conn = rightN.getNeuronConn(n.id);
                    error += rightN.deltaOutput * conn.getWeight();
                }
                n.deltaOutput = output * (1 - output) * error;
                for(int m = 0; m < n.getAllInConnections().size(); m++){
                    NeuronConnection conn = n.getAllInConnections().get(m);
                    double leftNOutput = conn.leftNeuron.getOutput();
                    double delta_weight = this.learningRate * n.deltaOutput * leftNOutput;
                    conn.setDeltaWeight(delta_weight);
                    conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());
                }
            }
        }
    }
    
    /*对外训练接口*/
    public void train(double[] inputs, double[] output){
        setInput(inputs);
        forward();
        updateWeights(output);
    }
    
    /*获取计算结果*/
    public double[] result(double[] inputs){
        setInput(inputs);
        forward();
        return getOutput();
    }
    
    /*获取权值矩阵*/
    public List<NeuronConnection> getWeights(){
        List<NeuronConnection> ret = new ArrayList<NeuronConnection>();
        for(List<Neuron> list : layerList){
            for(Neuron neuron : list){
                for(NeuronConnection conn : neuron.getAllInConnections()){
                    ret.add(conn);
                }
            }
        }
        return ret;
    }
    
    public static void main(String[] args){
        int train = 100;
        BPNN n = new BPNN(new int[]{2,2,1});
        while(train-- > 0){
            n.train( new double[]{0,0}, new double[]{0});
            n.toString();
            n.train( new double[]{1,0}, new double[]{1});
            n.toString();
            n.train( new double[]{0,1}, new double[]{1});
            n.toString();
            n.train( new double[]{1,1}, new double[]{0});
            n.toString();
            System.out.println("****************************************");
        }
        for(double d : n.result( new double[]{0,0})){
            System.out.println(d);
        }
        for(double d : n.result( new double[]{1,0})){
            System.out.println(d);
        }
        for(double d : n.result( new double[]{0,1})){
            System.out.println(d);
        }
        for(double d : n.result( new double[]{1,1})){
            System.out.println(d);
        }
    }
}


请大家告诉我到底是哪个地方出现了问题,谢谢您指导。
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值