1.创建增加学习算法的感知机
/**
* @author Ragty
* @param 增加学习算法的感知机(记忆逻辑与)
* @serialData 2018.4.22
* @param inputNeuralCount
*/
public void creatPerceptron(int inputNeuralCount){
//设置类型为感知机
this.setNetworkType(NeuralNetworkType.PERCEPTRON);
//建立输入神经元,表示输入刺激
NeuronProperties inputNeuronProperties = new NeuronProperties();
inputNeuronProperties.setProperty("neuronType",InputNeuron.class);
//建立输入层
Layer inputLayer = LayerFactory.createLayer(inputNeuralCount, inputNeuronProperties);
this.addLayer(inputLayer);
inputLayer.addNeuron(new BiasNeuron());
//建立输出神经元(传输函数为step)
NeuronProperties outputNeuronProperties = new NeuronProperties();
outputNeuronProperties.setProperty("transferFunction", TransferFunctionType.STEP);
//建立输出层
Layer outputLayer = LayerFactory.createLayer(1, outputNeuronProperties);
this.addLayer(outputLayer);
//输入层输出层全连接
ConnectionFactory.fullConnect(inputLayer, outputLayer);
NeuralNetworkFactory.setDefaultIO(this);
//设置感知机学习算法
this.setLearningRule(new perceptronLearningRule());
}
2.创建学习算法
public class perceptronLearningRule extends SupervisedLearning implements Serializable{
private static final long serialVersionUID = 1L;
public perceptronLearningRule() {
}
/**
* @author Ragty
* @param 迭代计算权值
* @serialData 2018.4.22
*/
@Override
protected void updateNetworkWeights(double[] outputError) {
int i = 0;
for (Neuron neuron : neuralNetwork.getOutputNeurons()) {
neuron.setError(outputError[i]);
double neuronError = neuron.getError();
// 根据所有的神经元输入 迭代学习
for (Connection connection : neuron.getInputConnections()) {
// 神经元的一个输入
double input = connection.getInput();
// 计算权值的变更
double weightChange = neuronError * input;
// 更新权值
Weight weight = connection.getWeight();
weight.weightChange = weightChange;
weight.value += weightChange;
}
i++;
}
}
}
public class AndPerceptron implements LearningEventListener{
public static void main(String[] args) {
new AndPerceptron().run();
}
public void run(){
//给出学习的训练数据(用于训练神经网络)
//数据集有两个输入,一个输出
//dataSetRow的构造函数接受两个参数,第一个为输入向量,第二个为期望值
DataSet trainningSet = new DataSet(2,1);
trainningSet.addRow(new DataSetRow(new double[]{0,0},new double[]{0}));
trainningSet.addRow(new DataSetRow(new double[]{0,1},new double[]{0}));
trainningSet.addRow(new DataSetRow(new double[]{1,0},new double[]{0}));
trainningSet.addRow(new DataSetRow(new double[]{1,1},new double[]{1}));
//创建一个只有两个输入节点的感知机
simplePerceptron andPerceptron = new simplePerceptron(2);
//给学习过程增加事件监听器(监督训练)
perceptronLearningRule learningRule = (perceptronLearningRule) andPerceptron.getLearningRule();
learningRule.addListener(this);
//使用训练数据训练感知机(进行学习)
System.out.println("训练开始");
andPerceptron.learn(trainningSet);
//测试感知机是否能正确输出
System.out.println("测试输出");
testNeuralNetwork(andPerceptron, trainningSet);
}
/**
* @author Ragty
* @param 训练之后对网络测试(测试感知机)
* @serialData 2018.4.22
* @param neuralNetwork
* @param data
*/
public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet testSet){
for(DataSetRow testSetRow : testSet.getRows()){
neuralNetwork.setInput(testSetRow.getInput());
neuralNetwork.calculate();
double[] networkOutput = neuralNetwork.getOutput();
System.out.println("Input:"+Arrays.toString(testSetRow.getInput()));
System.out.println("Output:"+Arrays.toString(networkOutput));
}
}
//监督训练过程
@Override
public void handleLearningEvent(LearningEvent event) {
// TODO Auto-generated method stub
//所有迭代学习算法的基类, 它为它的所有子类提供迭代学习过程
IterativeLearning bp = (IterativeLearning) event.getSource();
System.out.println("iterate:"+bp.getCurrentIteration());
System.out.println(Arrays.toString(bp.getNeuralNetwork().getWeights()));
}
}