多层神经网络解决XOR问题

import java.util.Arrays;

import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;

public class MultiPerceptron extends NeuralNetwork implements LearningEventListener {

	private static final long serialVersionUID = 1L;

	/**
	 * Trains a multi-layer perceptron on the XOR truth table and prints the
	 * network's output for every training row.
	 */
	@Test
	public void test() {
		// XOR training data: 2 inputs, 1 output.
		DataSet trainingSet = new DataSet(2, 1);

		trainingSet.add(new DataSetRow(new double[] { 0, 0 }, new double[] { 0 }));
		trainingSet.add(new DataSetRow(new double[] { 0, 1 }, new double[] { 1 }));
		trainingSet.add(new DataSetRow(new double[] { 1, 0 }, new double[] { 1 }));
		trainingSet.add(new DataSetRow(new double[] { 1, 1 }, new double[] { 0 }));

		// Multi-layer perceptron: 2 input neurons, 3 hidden neurons, 1 output
		// neuron, TANH transfer function for the final formatted output.
		// FIX: the original passed (2,2,1), contradicting this comment and the
		// 2.7 listing further down, which both use 3 hidden neurons.
		MultiLayerPerceptron myMultiPerceptron = new MultiLayerPerceptron(TransferFunctionType.TANH, 2, 3, 1);

		// MultiLayerPerceptron installs its own learning rule in its
		// constructor (this.setLearningRule(new MomentumBackpropagation())).
		// The listener MUST be registered before learn(); registering it
		// afterwards receives no iteration events.
		LearningRule lr = myMultiPerceptron.getLearningRule();
		lr.addListener(this);

		// Start training.
		System.out.println("Training neural network ...");
		myMultiPerceptron.learn(trainingSet);

		testNeuralNetwork(myMultiPerceptron, trainingSet);
	}

	/**
	 * Feeds every row of {@code test} through {@code nnet} and prints the
	 * input alongside the computed output.
	 *
	 * @param nnet trained network to evaluate
	 * @param test data set whose rows are used as evaluation inputs
	 */
	public static void testNeuralNetwork(NeuralNetwork nnet, DataSet test) {

		for (DataSetRow dataRow : test.getRows()) {

			nnet.setInput(dataRow.getInput());
			nnet.calculate();
			double[] networkOutput = nnet.getOutput();
			System.out.print("Input: " + Arrays.toString(dataRow.getInput()));
			System.out.println(" Output: " + Arrays.toString(networkOutput));
		}
	}

	/**
	 * Logs the current iteration number and total network error for every
	 * learning event except the final LEARNING_STOPPED notification.
	 */
	@Override
	public void handleLearningEvent(LearningEvent event) {
		SupervisedLearning bp = (SupervisedLearning) event.getSource();
		if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
			System.out.println(bp.getCurrentIteration() + ". iteration : " + bp.getTotalNetworkError());
		}
	}
}

注意:绑定监听器的代码必须写在学习训练方法 learn() 之前,否则监听无效。

课本采用的是 2.7 版本的 jar 包,其监听事件中没有枚举 LearningEvent.Type。但实测并不需要 if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED) 这个判断条件:learn() 方法结束后监听会自动停止。这也是为什么监听注册要放在 learn() 方法语句之前。
(此处原文插入了两张运行结果截图,转载时图片未能保留。)


下边放一份 2.7 版本下的代码:主要差别是训练集添加行的方法要改成 addRow(),监听回调中的判断条件也要相应修改。

import java.util.Arrays;

import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.DataSet;
import org.neuroph.core.learning.DataSetRow;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;

public class MultiPerceptron extends NeuralNetwork implements LearningEventListener {

	private static final long serialVersionUID = 1L;

	/**
	 * Trains a 2-3-1 TANH multi-layer perceptron on the XOR truth table using
	 * the Neuroph 2.7 API and prints the network's output for each row.
	 */
	@Test
	public void test() {
		// Build the XOR truth table; the 2.7 API uses addRow() rather than add().
		DataSet xorData = new DataSet(2, 1);
		double[][] inputs = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
		double[] targets = { 0, 1, 1, 0 };
		for (int i = 0; i < inputs.length; i++) {
			xorData.addRow(new DataSetRow(inputs[i], new double[] { targets[i] }));
		}

		// Perceptron layout: 2 input neurons, 3 hidden neurons, 1 output
		// neuron, with the TANH transfer function.
		MultiLayerPerceptron mlp = new MultiLayerPerceptron(TransferFunctionType.TANH, 2, 3, 1);

		// Register the listener before learn(); events fire only while
		// training is running.
		mlp.getLearningRule().addListener(this);

		// Start training.
		System.out.println("Training neural network ...");
		mlp.learn(xorData);

		testNeuralNetwork(mlp, xorData);

	}

	/**
	 * Prints each row's input together with the output computed by the
	 * given network.
	 *
	 * @param nnet trained network to evaluate
	 * @param test data set whose rows are fed to the network
	 */
	public static void testNeuralNetwork(NeuralNetwork nnet, DataSet test) {

		for (DataSetRow row : test.getRows()) {

			nnet.setInput(row.getInput());
			nnet.calculate();
			System.out.print("Input: " + Arrays.toString(row.getInput()));
			System.out.println(" Output: " + Arrays.toString(nnet.getOutput()));
		}
	}

	/**
	 * Logs the iteration number and total network error for every learning
	 * event (the 2.7 API has no LearningEvent.Type to filter on).
	 */
	@Override
	public void handleLearningEvent(LearningEvent event) {
		SupervisedLearning learning = (SupervisedLearning) event.getSource();
		System.out.println(learning.getCurrentIteration() + ". iteration : " + learning.getTotalNetworkError());
	}


}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值