import java.util.Arrays;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;
public class MultiPerceptron extends NeuralNetwork implements LearningEventListener {
    private static final long serialVersionUID = 1L;

    /**
     * Trains a multilayer perceptron on the XOR truth table and prints the
     * network's output for every training row.
     */
    @Test
    public void test() {
        // XOR training set: 2 inputs, 1 output.
        DataSet trainingSet = new DataSet(2, 1);
        trainingSet.add(new DataSetRow(new double[] { 0, 0 }, new double[] { 0 }));
        trainingSet.add(new DataSetRow(new double[] { 0, 1 }, new double[] { 1 }));
        trainingSet.add(new DataSetRow(new double[] { 1, 0 }, new double[] { 1 }));
        trainingSet.add(new DataSetRow(new double[] { 1, 1 }, new double[] { 0 }));
        // Multilayer perceptron: 2 input neurons, 3 hidden neurons, 1 output
        // neuron, with the TANH transfer function for the final output.
        // FIX: was (TANH, 2, 2, 1), which contradicted the original comment
        // ("3 hidden neurons") and the 2.7 variant of this example.
        MultiLayerPerceptron myMultiPerceptron = new MultiLayerPerceptron(TransferFunctionType.TANH, 2, 3, 1);
        // MultiLayerPerceptron installs its default learning rule internally
        // (MomentumBackpropagation per the original author's note).
        // The listener MUST be attached BEFORE learn(), otherwise the
        // per-iteration events are never delivered to this class.
        LearningRule lr = myMultiPerceptron.getLearningRule();
        lr.addListener(this);
        System.out.println("Training neural network ...");
        myMultiPerceptron.learn(trainingSet);
        testNeuralNetwork(myMultiPerceptron, trainingSet);
    }

    /**
     * Feeds every row of {@code test} through {@code nnet} and prints the
     * input vector alongside the computed output vector.
     *
     * @param nnet trained network to evaluate
     * @param test data set whose rows are used as inputs
     */
    public static void testNeuralNetwork(NeuralNetwork nnet, DataSet test) {
        for (DataSetRow dataRow : test.getRows()) {
            nnet.setInput(dataRow.getInput());
            nnet.calculate();
            double[] networkOutput = nnet.getOutput();
            System.out.print("Input: " + Arrays.toString(dataRow.getInput()));
            System.out.println(" Output: " + Arrays.toString(networkOutput));
        }
    }

    /**
     * Logs the iteration number and total network error for every learning
     * event except the final LEARNING_STOPPED notification.
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        SupervisedLearning bp = (SupervisedLearning) event.getSource();
        if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
            System.out.println(bp.getCurrentIteration() + ". iteration : " + bp.getTotalNetworkError());
        }
    }
}
注意:绑定监听器的语句必须写在 learn() 训练方法之前,否则监听无效。
课本采用的是 2.7 版本的 jar 包,该版本的监听器中没有枚举 LearningEvent.Type。实测并不需要 if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED)
这个判断条件:learn() 方法结束后监听也会自动停止,这也是为什么监听器要注册在 learn() 方法语句之前。
下面给出 2.7 版本下的代码:主要改动是向训练集添加数据行的方法要改成 addRow(),监听器中的判断条件也要相应修改。
import java.util.Arrays;
import org.junit.Test;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.DataSet;
import org.neuroph.core.learning.DataSetRow;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;
public class MultiPerceptron extends NeuralNetwork implements LearningEventListener {
    private static final long serialVersionUID = 1L;

    /**
     * Trains a multilayer perceptron on the XOR problem (Neuroph 2.7 API)
     * and prints the network output for each training row.
     */
    @Test
    public void test() {
        // Build the XOR training set: two inputs, one target output per row.
        DataSet xorSet = new DataSet(2, 1);
        double[][] inputs  = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
        double[]   targets = { 0, 1, 1, 0 };
        for (int i = 0; i < inputs.length; i++) {
            xorSet.addRow(new DataSetRow(inputs[i], new double[] { targets[i] }));
        }
        // Perceptron layout: 2 input neurons, 3 hidden neurons, 1 output
        // neuron, TANH transfer function on the output.
        MultiLayerPerceptron perceptron = new MultiLayerPerceptron(TransferFunctionType.TANH, 2, 3, 1);
        // Attach the iteration listener before learn(), otherwise no events
        // are observed.
        perceptron.getLearningRule().addListener(this);
        System.out.println("Training neural network ...");
        perceptron.learn(xorSet);
        testNeuralNetwork(perceptron, xorSet);
    }

    /**
     * Runs every row of {@code test} through {@code nnet} and prints each
     * input vector with the corresponding computed output.
     */
    public static void testNeuralNetwork(NeuralNetwork nnet, DataSet test) {
        for (DataSetRow row : test.getRows()) {
            nnet.setInput(row.getInput());
            nnet.calculate();
            System.out.print("Input: " + Arrays.toString(row.getInput()));
            System.out.println(" Output: " + Arrays.toString(nnet.getOutput()));
        }
    }

    /** Prints the current iteration number and total network error. */
    @Override
    public void handleLearningEvent(LearningEvent arg0) {
        SupervisedLearning bp = (SupervisedLearning) arg0.getSource();
        System.out.println(bp.getCurrentIteration() + ". iteration : " + bp.getTotalNetworkError());
    }
}