ANN XOR example (a neural network that learns the XOR function)


/**
* Network
* Copyright 2005 by Jeff Heaton(jeff@jeffheaton.com)
*
* Example program from Chapter 3
* Programming Neural Networks in Java
* http://www.heatonresearch.com/articles/series/1/
*
* This software is copyrighted. You may use it in programs
* of your own, without restriction, but you may not
* publish the source code without the author's permission.
* For more information on distributing this code, please
* visit:
* http://www.heatonresearch.com/hr_legal.php
*
* @author Jeff Heaton
* @version 1.1
*/

public class Network {

    /**
     * Sum of the squared output-neuron errors accumulated by
     * {@link #calcError(double[])} since the last call to {@link #getError(int)}.
     */
    protected double globalError;

    /**
     * The number of input neurons.
     */
    protected int inputCount;

    /**
     * The number of hidden neurons.
     */
    protected int hiddenCount;

    /**
     * The number of output neurons.
     */
    protected int outputCount;

    /**
     * The total number of neurons in the network (input + hidden + output).
     */
    protected int neuronCount;

    /**
     * The number of weights: input-to-hidden plus hidden-to-output connections.
     */
    protected int weightCount;

    /**
     * The learning rate applied to the accumulated deltas in {@link #learn()}.
     */
    protected double learnRate;

    /**
     * The most recent output of every neuron. Indices run over the input
     * neurons first, then the hidden neurons, then the output neurons.
     */
    protected double fire[];

    /**
     * The weight matrix, stored flat. The first inputCount * hiddenCount
     * entries are the input-to-hidden weights, grouped by hidden neuron; the
     * remaining entries are the hidden-to-output weights, grouped by output
     * neuron. Together with the thresholds this can be thought of as the
     * "memory" of the neural network.
     */
    protected double matrix[];

    /**
     * The per-neuron errors from the last call to {@link #calcError(double[])}.
     */
    protected double error[];

    /**
     * Weight deltas accumulated by {@link #calcError(double[])} and applied
     * (then cleared) by {@link #learn()}.
     */
    protected double accMatrixDelta[];

    /**
     * The per-neuron thresholds (biases). Together with the weight matrix this
     * can be thought of as the memory of the neural network. The entries for
     * the input neurons are not used when computing outputs.
     */
    protected double thresholds[];

    /**
     * The weight changes applied by the last call to {@link #learn()}; they are
     * scaled by the momentum and folded into the next change.
     */
    protected double matrixDelta[];

    /**
     * Threshold deltas accumulated by {@link #calcError(double[])} and applied
     * (then cleared) by {@link #learn()}.
     */
    protected double accThresholdDelta[];

    /**
     * The threshold changes applied by the last call to {@link #learn()}; they
     * are scaled by the momentum and folded into the next change.
     */
    protected double thresholdDelta[];

    /**
     * The momentum: the fraction of the previous weight/threshold change that
     * is added to the next one.
     */
    protected double momentum;

    /**
     * Per-neuron error multiplied by the sigmoid derivative fire * (1 - fire),
     * as computed in {@link #calcError(double[])}.
     */
    protected double errorDelta[];


    /**
     * Construct the neural network and randomize its weights and thresholds.
     *
     * @param inputCount The number of input neurons.
     * @param hiddenCount The number of hidden neurons.
     * @param outputCount The number of output neurons.
     * @param learnRate The learning rate to be used when training.
     * @param momentum The momentum to be used when training.
     * @throws IllegalArgumentException if any neuron count is negative.
     */
    public Network(int inputCount,
            int hiddenCount,
            int outputCount,
            double learnRate,
            double momentum) {
        if (inputCount < 0 || hiddenCount < 0 || outputCount < 0) {
            throw new IllegalArgumentException(
                    "Neuron counts must be non-negative: input=" + inputCount
                    + ", hidden=" + hiddenCount + ", output=" + outputCount);
        }

        this.learnRate = learnRate;
        this.momentum = momentum;

        this.inputCount = inputCount;
        this.hiddenCount = hiddenCount;
        this.outputCount = outputCount;
        neuronCount = inputCount + hiddenCount + outputCount;
        weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);

        fire = new double[neuronCount];
        matrix = new double[weightCount];
        matrixDelta = new double[weightCount];
        thresholds = new double[neuronCount];
        errorDelta = new double[neuronCount];
        error = new double[neuronCount];
        accThresholdDelta = new double[neuronCount];
        accMatrixDelta = new double[weightCount];
        thresholdDelta = new double[neuronCount];

        // NOTE(review): reset() is public and overridable, so a subclass
        // override runs here before the subclass's own fields are initialized.
        reset();
    }


    /**
     * Returns the root-mean-square error accumulated since the previous call
     * and clears the accumulator.
     *
     * @param len The number of training samples passed to calcError since the
     *            previous call (e.g. the length of one complete training set).
     * @return sqrt(globalError / (len * outputCount)).
     * @throws IllegalArgumentException if len is not positive.
     */
    public double getError(int len) {
        if (len <= 0) {
            throw new IllegalArgumentException("len must be positive: " + len);
        }
        // Multiply in double so large len * outputCount cannot overflow int.
        double err = Math.sqrt(globalError / ((double) len * outputCount));
        globalError = 0; // clear the accumulator
        return err;
    }

    /**
     * The threshold (activation) method: the logistic sigmoid
     * 1 / (1 + e^-sum). You may override this method to provide another
     * activation, but note that {@link #calcError(double[])} uses the sigmoid
     * derivative fire * (1 - fire) and would need to be overridden to match.
     *
     * @param sum The weighted input sum of the neuron.
     * @return The activation applied to the sum.
     */
    public double threshold(double sum) {
        return 1.0 / (1 + Math.exp(-1.0 * sum));
    }

    /**
     * Compute the output for a given input to the neural network.
     *
     * @param input The input provided to the neural network; only the first
     *              inputCount values are used.
     * @return The results from the output neurons (a new array of length
     *         outputCount).
     * @throws IllegalArgumentException if input has fewer than inputCount values.
     */
    public double[] computeOutputs(double input[]) {
        if (input.length < inputCount) {
            throw new IllegalArgumentException("Expected at least " + inputCount
                    + " inputs but got " + input.length);
        }
        final int hiddenIndex = inputCount;
        final int outIndex = inputCount + hiddenCount;

        // Input layer: the activations are just the raw inputs.
        System.arraycopy(input, 0, fire, 0, inputCount);

        // Hidden layer: threshold plus weighted sum of the input activations.
        int inx = 0;
        for (int i = hiddenIndex; i < outIndex; i++) {
            double sum = thresholds[i];
            for (int j = 0; j < inputCount; j++) {
                sum += fire[j] * matrix[inx++];
            }
            fire[i] = threshold(sum);
        }

        // Output layer: threshold plus weighted sum of the hidden activations.
        double result[] = new double[outputCount];
        for (int i = outIndex; i < neuronCount; i++) {
            double sum = thresholds[i];
            for (int j = hiddenIndex; j < outIndex; j++) {
                sum += fire[j] * matrix[inx++];
            }
            fire[i] = threshold(sum);
            result[i - outIndex] = fire[i];
        }

        return result;
    }


    /**
     * Calculate the error for the recognition just done by
     * {@link #computeOutputs(double[])}, add the squared output errors to
     * globalError, and accumulate weight/threshold deltas for {@link #learn()}.
     *
     * @param ideal What the output neurons should have yielded.
     * @throws IllegalArgumentException if ideal has fewer than outputCount values.
     */
    public void calcError(double ideal[]) {
        if (ideal.length < outputCount) {
            throw new IllegalArgumentException("Expected at least " + outputCount
                    + " ideal values but got " + ideal.length);
        }
        final int hiddenIndex = inputCount;
        final int outputIndex = inputCount + hiddenCount;

        // Clear the errors of every layer. The input-layer errors are also
        // accumulated with += below, so they must be reset too or they grow
        // without bound across calls.
        for (int i = 0; i < neuronCount; i++) {
            error[i] = 0;
        }

        // Output layer: errors and deltas.
        for (int i = outputIndex; i < neuronCount; i++) {
            error[i] = ideal[i - outputIndex] - fire[i];
            globalError += error[i] * error[i];
            errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
        }

        // Back-propagate to the hidden layer and accumulate the
        // hidden-to-output weight and output threshold deltas.
        int winx = inputCount * hiddenCount;
        for (int i = outputIndex; i < neuronCount; i++) {
            for (int j = hiddenIndex; j < outputIndex; j++) {
                accMatrixDelta[winx] += errorDelta[i] * fire[j];
                error[j] += matrix[winx] * errorDelta[i];
                winx++;
            }
            accThresholdDelta[i] += errorDelta[i];
        }

        // Hidden layer deltas.
        for (int i = hiddenIndex; i < outputIndex; i++) {
            errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
        }

        // Back-propagate to the input layer and accumulate the input-to-hidden
        // weight and hidden threshold deltas.
        winx = 0; // offset into weight array
        for (int i = hiddenIndex; i < outputIndex; i++) {
            for (int j = 0; j < hiddenIndex; j++) {
                accMatrixDelta[winx] += errorDelta[i] * fire[j];
                error[j] += matrix[winx] * errorDelta[i];
                winx++;
            }
            accThresholdDelta[i] += errorDelta[i];
        }
    }

    /**
     * Modify the weight matrix and thresholds using the deltas accumulated by
     * calcError, plus momentum times the previous change, then clear the
     * accumulators.
     */
    public void learn() {
        for (int i = 0; i < matrix.length; i++) {
            matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
            matrix[i] += matrixDelta[i];
            accMatrixDelta[i] = 0;
        }

        // Input neurons have no threshold in use, so start at the hidden layer.
        for (int i = inputCount; i < neuronCount; i++) {
            thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
            thresholds[i] += thresholdDelta[i];
            accThresholdDelta[i] = 0;
        }
    }

    /**
     * Reset the weights and thresholds to random values in (-0.5, 0.5] and
     * clear all recorded and accumulated changes.
     */
    public void reset() {
        for (int i = 0; i < neuronCount; i++) {
            thresholds[i] = 0.5 - (Math.random());
            thresholdDelta[i] = 0;
            accThresholdDelta[i] = 0;
        }
        for (int i = 0; i < matrix.length; i++) {
            matrix[i] = 0.5 - (Math.random());
            matrixDelta[i] = 0;
            accMatrixDelta[i] = 0;
        }
    }
}




import javax.swing.*;
import java.awt.*;
import java.awt.event.*;
import java.text.*;

/**
* XorExample
* Copyright 2005 by Jeff Heaton(jeff@jeffheaton.com)
*
* Example program from Chapter 3
* Programming Neural Networks in Java
* http://www.heatonresearch.com/articles/series/1/
*
* This software is copyrighted. You may use it in programs
* of your own, without restriction, but you may not
* publish the source code without the author's permission.
* For more information on distributing this code, please
* visit:
* http://www.heatonresearch.com/hr_legal.php
*
* @author Jeff Heaton
* @version 1.1
*/
public class XorExample extends JFrame implements
        ActionListener, Runnable {

    /**
     * The train button.
     */
    JButton btnTrain;

    /**
     * The run button.
     */
    JButton btnRun;

    /**
     * The quit button.
     */
    JButton btnQuit;

    /**
     * The status line.
     */
    JLabel status;

    /**
     * The background worker thread that performs training; null until the
     * first training run.
     */
    protected Thread worker = null;

    /**
     * The number of input neurons.
     */
    protected final static int NUM_INPUT = 2;

    /**
     * The number of output neurons.
     */
    protected final static int NUM_OUTPUT = 1;

    /**
     * The number of hidden neurons.
     */
    protected final static int NUM_HIDDEN = 3;

    /**
     * The learning rate.
     */
    protected final static double RATE = 0.5;

    /**
     * The learning momentum.
     */
    protected final static double MOMENTUM = 0.7;

    /**
     * The number of passes over the training set in one training run.
     */
    private static final int TRAINING_CYCLES = 10000;

    /**
     * How often, in cycles, the status line is refreshed during training.
     */
    private static final int STATUS_INTERVAL = 100;


    /**
     * The training data that the user enters.
     * This represents the inputs and expected
     * outputs for the XOR problem.
     */
    protected JTextField data[][] = new JTextField[4][4];

    /**
     * The neural network.
     */
    protected Network network;

    /**
     * Training inputs parsed on the event dispatch thread by train(), so the
     * worker thread never has to read Swing components.
     */
    private double[][] trainingInput;

    /**
     * Ideal outputs parsed on the event dispatch thread by train(), so the
     * worker thread never has to read Swing components.
     */
    private double[][] trainingIdeal;


    /**
     * Constructor. Setup the components.
     */
    public XorExample() {
        setTitle("XOR Solution");
        network = new Network(
                NUM_INPUT,
                NUM_HIDDEN,
                NUM_OUTPUT,
                RATE,
                MOMENTUM);

        Container content = getContentPane();

        GridBagLayout gridbag = new GridBagLayout();
        GridBagConstraints c = new GridBagConstraints();
        content.setLayout(gridbag);

        c.fill = GridBagConstraints.NONE;
        c.weightx = 1.0;

        // Training input label
        c.gridwidth = GridBagConstraints.REMAINDER; //end row
        c.anchor = GridBagConstraints.NORTHWEST;
        content.add(
                new JLabel(
                        "Enter training data:"), c);

        JPanel grid = new JPanel();
        grid.setLayout(new GridLayout(5, 4));
        grid.add(new JLabel("IN1"));
        grid.add(new JLabel("IN2"));
        grid.add(new JLabel("Expected OUT "));
        grid.add(new JLabel("Actual OUT"));

        // One row per XOR input combination; only the expected output is editable.
        for (int i = 0; i < 4; i++) {
            int x = (i & 1);
            int y = (i & 2) >> 1;
            grid.add(data[i][0] = new JTextField("" + y));
            grid.add(data[i][1] = new JTextField("" + x));
            grid.add(data[i][2] = new JTextField("" + (x ^ y)));
            grid.add(data[i][3] = new JTextField("??"));
            data[i][0].setEditable(false);
            data[i][1].setEditable(false);
            data[i][3].setEditable(false);
        }

        content.add(grid, c);

        // the button panel
        JPanel buttonPanel = new JPanel(new FlowLayout());
        buttonPanel.add(btnTrain = new JButton("Train"));
        buttonPanel.add(btnRun = new JButton("Run"));
        buttonPanel.add(btnQuit = new JButton("Quit"));
        btnTrain.addActionListener(this);
        btnRun.addActionListener(this);
        btnQuit.addActionListener(this);

        // Add the button panel
        c.gridwidth = GridBagConstraints.REMAINDER; //end row
        c.anchor = GridBagConstraints.CENTER;
        content.add(buttonPanel, c);

        // Status line
        c.gridwidth = GridBagConstraints.REMAINDER; //end row
        c.anchor = GridBagConstraints.NORTHWEST;
        content.add(
                status = new JLabel("Click train to begin training..."), c);

        // adjust size and position (centered on screen)
        pack();
        Toolkit toolkit = Toolkit.getDefaultToolkit();
        Dimension d = toolkit.getScreenSize();
        setLocation(
                (int) (d.width - this.getSize().getWidth()) / 2,
                (int) (d.height - this.getSize().getHeight()) / 2);
        setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
        setResizable(false);

        // Running is only meaningful after the network has been trained.
        btnRun.setEnabled(false);
    }

    /**
     * The main function, just display the JFrame. The frame is created and
     * shown on the event dispatch thread, as Swing requires.
     *
     * @param args No arguments are used.
     */
    public static void main(String args[]) {
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                new XorExample().setVisible(true);
            }
        });
    }

    /**
     * Called when the user clicks one of the three
     * buttons.
     *
     * @param e The event.
     */
    public void actionPerformed(ActionEvent e) {
        Object source = e.getSource();
        if (source == btnQuit) {
            System.exit(0);
        } else if (source == btnTrain) {
            train();
        } else if (source == btnRun) {
            evaluate();
        }
    }

    /**
     * Called when the user clicks the run button: feeds each row of the grid
     * through the network and shows the output in the "Actual OUT" column.
     */
    protected void evaluate() {
        double xorData[][] = getGrid();
        NumberFormat nf = NumberFormat.getInstance();

        for (int i = 0; i < xorData.length; i++) {
            double d[] = network.computeOutputs(xorData[i]);
            data[i][3].setText(nf.format(d[0]));
        }
    }


    /**
     * Called when the user clicks the train button. Parses the training data
     * on the event dispatch thread, then starts a background worker to train
     * the network. Does nothing if a training run is already in progress.
     */
    protected void train() {
        if (worker != null && worker.isAlive()) {
            return; // never train the same network from two threads at once
        }
        try {
            trainingInput = getGrid();
            trainingIdeal = getIdeal();
        } catch (NumberFormatException ex) {
            status.setText("Invalid training data: " + ex.getMessage());
            return;
        }
        btnTrain.setEnabled(false);
        btnRun.setEnabled(false);
        worker = new Thread(this);
        worker.setPriority(Thread.MIN_PRIORITY);
        // Don't keep the JVM alive once the window has been disposed.
        worker.setDaemon(true);
        worker.start();
    }

    /**
     * The thread worker, used for training. All Swing updates are posted to
     * the event dispatch thread.
     */
    public void run() {
        // Fall back to reading the grid if run() is invoked without train().
        double xorData[][] = (trainingInput != null) ? trainingInput : getGrid();
        double xorIdeal[][] = (trainingIdeal != null) ? trainingIdeal : getIdeal();
        boolean completed = false;

        try {
            for (int i = 0; i < TRAINING_CYCLES; i++) {
                for (int j = 0; j < xorData.length; j++) {
                    network.computeOutputs(xorData[j]);
                    network.calcError(xorIdeal[j]);
                    network.learn();
                }

                // getError divides by the sample count it is given, so read
                // (and reset) it every cycle; reading it only every
                // STATUS_INTERVAL cycles with len = xorData.length overstated
                // the RMS error by a factor of sqrt(STATUS_INTERVAL).
                double err = network.getError(xorData.length);
                if ((i + 1) % STATUS_INTERVAL == 0) {
                    setStatusLater("Cycles Left:" + (TRAINING_CYCLES - i - 1) + ",Error:" + err);
                }
            }
            completed = true;
        } finally {
            trainingFinishedLater(completed);
        }
    }

    /**
     * Posts a status-line update to the event dispatch thread.
     *
     * @param text The new status text.
     */
    private void setStatusLater(final String text) {
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                status.setText(text);
            }
        });
    }

    /**
     * Re-enables the buttons on the event dispatch thread once training ends.
     *
     * @param completed Whether training ran to completion; Run is only
     *                  enabled in that case.
     */
    private void trainingFinishedLater(final boolean completed) {
        SwingUtilities.invokeLater(new Runnable() {
            public void run() {
                btnTrain.setEnabled(true);
                if (completed) {
                    btnRun.setEnabled(true);
                }
            }
        });
    }


    /**
     * Called to generate an array of doubles based on
     * the training inputs shown in the grid.
     *
     * @return An array of doubles, one row of two inputs per grid row.
     * @throws NumberFormatException if a cell does not contain a number.
     */
    double[][] getGrid() {
        double array[][] = new double[data.length][2];

        for (int i = 0; i < data.length; i++) {
            array[i][0] = Double.parseDouble(data[i][0].getText());
            array[i][1] = Double.parseDouble(data[i][1].getText());
        }

        return array;
    }

    /**
     * Called to get the ideal values that the neural network
     * should return for each of the grid training values.
     *
     * @return The ideal results, one single-element row per grid row.
     * @throws NumberFormatException if a cell does not contain a number.
     */
    double[][] getIdeal() {
        double array[][] = new double[data.length][1];

        for (int i = 0; i < data.length; i++) {
            array[i][0] = Double.parseDouble(data[i][2].getText());
        }

        return array;
    }


}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值