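Steps:
- Download the JAR package. I used neuroph-2.98.zip; note that this is not neurophstudio-windows-2.98.exe;
- Create a Java project;
- Add the library references to the project; for this example, four JARs from the ZIP are enough;
- Make a few small changes to the sample code: add per-iteration log output, and add a test that runs over the whole data set;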
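Since a missing reference is the easiest step to get wrong, here is a small standalone check you could run after adding the JARs. It is a sketch, not part of Neuroph: the class `ClasspathCheck` and its `isOnClasspath` helper are hypothetical names. It only tries to load three of the classes the example imports, using the standard `Class.forName(String, boolean, ClassLoader)`, so a "found" result does not prove that every transitive dependency of those JARs is present.

```java
/**
 * Hypothetical helper (not part of Neuroph): reports whether the Neuroph
 * classes used in this post can be loaded from the current classpath.
 */
public class ClasspathCheck {

    /** Returns true if the named class can be loaded (without initializing it). */
    public static boolean isOnClasspath(String className) {
        try {
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            // ClassNotFoundException: the class itself is missing;
            // LinkageError (e.g. NoClassDefFoundError): one of its supertypes is missing
            return false;
        }
    }

    public static void main(String[] args) {
        // a few of the classes imported by the example program
        String[] required = {
            "org.neuroph.core.NeuralNetwork",
            "org.neuroph.nnet.Perceptron",
            "org.neuroph.nnet.learning.BinaryDeltaRule"
        };
        for (String name : required) {
            System.out.println((isOnClasspath(name) ? "found   " : "MISSING ") + name);
        }
    }
}
```

Run it with the same classpath as the project; any line marked MISSING points to a JAR that still needs to be added.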
Here is the code:
```java
import java.util.Arrays;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.Perceptron;
import org.neuroph.nnet.learning.BinaryDeltaRule;

public class MyPerceptron implements LearningEventListener {

    public static void main(String[] args) {
        new MyPerceptron().testPerceptron();
    }

    public void testPerceptron() {
        // create a new perceptron network with 2 inputs and 1 output
        NeuralNetwork myPerceptron = new Perceptron(2, 1);

        // create the training set for logical AND
        DataSet trainingSet = new DataSet(2, 1);
        trainingSet.add(new double[]{0, 0}, new double[]{0});
        trainingSet.add(new double[]{0, 1}, new double[]{0});
        trainingSet.add(new double[]{1, 0}, new double[]{0});
        trainingSet.add(new double[]{1, 1}, new double[]{1});

        // register this object as a listener so each iteration is logged
        LearningRule lr = myPerceptron.getLearningRule();
        lr.addListener(this);

        // learn the training set
        myPerceptron.learn(trainingSet);

        // save the trained network to a file (it learned AND, so name it accordingly)
        myPerceptron.save("and_perceptron.nnet");

        // load the saved network back from the file
        NeuralNetwork loadedNetwork = NeuralNetwork.createFromFile("and_perceptron.nnet");

        // test the loaded network with a single input
        loadedNetwork.setInput(0, 0);
        loadedNetwork.calculate();
        double[] networkOutput = loadedNetwork.getOutput();
        System.out.println(networkOutput[0]);

        // test the trained network on every row of the training set
        testNeuralNetwork(myPerceptron, trainingSet);
    }

    public void testNeuralNetwork(NeuralNetwork nnet, DataSet tset) {
        for (DataSetRow dataRow : tset.getRows()) {
            nnet.setInput(dataRow.getInput());
            nnet.calculate();
            double[] networkOutput = nnet.getOutput();
            System.out.print("Input: " + Arrays.toString(dataRow.getInput()));
            System.out.println(" Output: " + Arrays.toString(networkOutput));
        }
    }

    @Override
    public void handleLearningEvent(LearningEvent event) {
        // the cast assumes Perceptron's default learning rule is BinaryDeltaRule (as in Neuroph 2.98)
        BinaryDeltaRule rule = (BinaryDeltaRule) event.getSource();
        // log every event except the final "learning stopped" one
        if (event.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
            System.out.println(rule.getCurrentIteration() + ". iteration : " + rule.getTotalNetworkError());
        }
    }
}
```
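To get a feel for what `learn()` does for this network, here is a hand-rolled perceptron trained on the same AND data with the classic perceptron update rule, w = w + rate * (target - output) * x, using a step activation. This is an illustrative sketch, not Neuroph's implementation: the class name `PerceptronSketch`, the zero initial weights, the learning rate 0.5 and the per-epoch misclassification count passed to the callback are all my own assumptions. Neuroph starts from random weights and its log reports a total network error instead, so its output will look different. The `BiConsumer` callback plays the same role as the `LearningEventListener` in the code above.

```java
import java.util.Arrays;
import java.util.function.BiConsumer;

/** Illustrative single-layer perceptron (hypothetical, not Neuroph code). */
public class PerceptronSketch {
    private final double[] weights;
    private double bias;
    private final double learningRate;

    public PerceptronSketch(int inputs, double learningRate) {
        this.weights = new double[inputs]; // zero initial weights (assumption, for reproducibility)
        this.bias = 0.0;
        this.learningRate = learningRate;
    }

    /** Step activation: 1 if the weighted sum plus bias is positive, else 0. */
    public double output(double[] x) {
        double sum = bias;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * x[i];
        }
        return sum > 0 ? 1.0 : 0.0;
    }

    /**
     * Trains until an epoch has no misclassified rows or maxEpochs is reached.
     * After each epoch the listener receives (epoch number, misclassified rows).
     * Returns the number of epochs that were run.
     */
    public int train(double[][] inputs, double[] targets, int maxEpochs,
                     BiConsumer<Integer, Integer> listener) {
        for (int epoch = 1; epoch <= maxEpochs; epoch++) {
            int errors = 0;
            for (int r = 0; r < inputs.length; r++) {
                double error = targets[r] - output(inputs[r]);
                if (error != 0) {
                    errors++;
                    // perceptron rule: move the weights toward the target
                    for (int i = 0; i < weights.length; i++) {
                        weights[i] += learningRate * error * inputs[r][i];
                    }
                    bias += learningRate * error;
                }
            }
            listener.accept(epoch, errors);
            if (errors == 0) {
                return epoch;
            }
        }
        return maxEpochs;
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        double[] targets = {0, 0, 0, 1}; // AND truth table
        PerceptronSketch p = new PerceptronSketch(2, 0.5);
        p.train(inputs, targets, 100,
                (epoch, errors) -> System.out.println(epoch + ". iteration : " + errors + " misclassified"));
        for (double[] row : inputs) {
            System.out.println("Input: " + Arrays.toString(row) + " Output: " + p.output(row));
        }
    }
}
```

Because AND is linearly separable, the perceptron convergence theorem guarantees that the loop ends with zero misclassified rows after a finite number of epochs. With zero starting weights and 0/1 inputs, a learning rate of 0.5 also keeps every weight an exact binary fraction, so rounding never affects the step threshold.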
References:
- Getting Started with Neuroph 2.98.pdf (included in the downloaded ZIP)
- Neural Networks and Deep Learning (《神经网络与深度学习》)
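All I can say is that, for a Java beginner like me, neither of these two resources is very friendly: either a required library reference goes unmentioned or the code is incomplete. Something that looks simple turned into one pitfall after another, so it seemed worth writing down. I hope it helps others.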