/** * Network * Copyright 2005 by Jeff Heaton(jeff@jeffheaton.com) * * Example program from Chapter 3 * Programming Neural Networks in Java * http://www.heatonresearch.com/articles/series/1/ * * This software is copyrighted. You may use it in programs * of your own, without restriction, but you may not * publish the source code without the author's permission. * For more information on distributing this code, please * visit: * http://www.heatonresearch.com/hr_legal.php * * @author Jeff Heaton * @version 1.1 */
public class Network {

    /** The squared output error accumulated by {@link #calcError} since the last {@link #getError}. */
    protected double globalError;

    /** The number of input neurons. */
    protected int inputCount;

    /** The number of hidden neurons. */
    protected int hiddenCount;

    /** The number of output neurons. */
    protected int outputCount;

    /** The total number of neurons in the network (input + hidden + output). */
    protected int neuronCount;

    /** The number of weights in the network (input*hidden + hidden*output). */
    protected int weightCount;

    /** The learning rate. */
    protected double learnRate;

    /** The outputs from the various levels, indexed by neuron: inputs, then hidden, then outputs. */
    protected double[] fire;

    /**
     * The weight matrix; this, along with the thresholds, can be thought of as
     * the "memory" of the neural network. The input-to-hidden weights come
     * first (one run of {@code inputCount} weights per hidden neuron), followed
     * by the hidden-to-output weights (one run of {@code hiddenCount} weights
     * per output neuron).
     */
    protected double[] matrix;

    /** The errors from the last calculation, indexed by neuron. */
    protected double[] error;

    /**
     * The thresholds; these, along with the weight matrix, can be thought of as
     * the memory of the neural network. Indexed by neuron.
     */
    protected double[] thresholds;

    /** The changes most recently applied to the weight matrix (kept for the momentum term). */
    protected double[] matrixDelta;

    /** The weight changes accumulated by {@link #calcError} until {@link #learn} applies them. */
    protected double[] accMatrixDelta;

    /** The threshold changes accumulated by {@link #calcError} until {@link #learn} applies them. */
    protected double[] accThresholdDelta;

    /** The changes most recently applied to the thresholds (kept for the momentum term). */
    protected double[] thresholdDelta;

    /** The momentum for training. */
    protected double momentum;

    /** The error deltas (error scaled by the sigmoid derivative), indexed by neuron. */
    protected double[] errorDelta;

    /**
     * Construct the neural network.
     *
     * @param inputCount  The number of input neurons.
     * @param hiddenCount The number of hidden neurons.
     * @param outputCount The number of output neurons.
     * @param learnRate   The learning rate to be used when training.
     * @param momentum    The momentum to be used when training.
     */
    public Network(int inputCount, int hiddenCount, int outputCount,
            double learnRate, double momentum) {
        // Store the configuration first: the array sizes below depend on it.
        this.learnRate = learnRate;
        this.momentum = momentum;
        this.inputCount = inputCount;
        this.hiddenCount = hiddenCount;
        this.outputCount = outputCount;
        neuronCount = inputCount + hiddenCount + outputCount;
        weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);

        fire = new double[neuronCount];
        matrix = new double[weightCount];
        matrixDelta = new double[weightCount];
        thresholds = new double[neuronCount];
        errorDelta = new double[neuronCount];
        error = new double[neuronCount];
        accThresholdDelta = new double[neuronCount];
        accMatrixDelta = new double[weightCount];
        thresholdDelta = new double[neuronCount];

        reset();
    }

    /**
     * Returns the root mean square error for a complete training set and
     * clears the error accumulator.
     *
     * @param len The length of a complete training set.
     * @return The current error for the neural network.
     */
    public double getError(int len) {
        // Multiply in double so a large training set cannot overflow int.
        double err = Math.sqrt(globalError / ((double) len * outputCount));
        globalError = 0; // clear the accumulator
        return err;
    }

    /**
     * The threshold (activation) method, a logistic sigmoid. You may wish to
     * override this method to provide other threshold methods; note that
     * {@link #calcError} uses the sigmoid derivative {@code f * (1 - f)}.
     *
     * @param sum The activation from the neuron.
     * @return The activation applied to the threshold method.
     */
    public double threshold(double sum) {
        return 1.0 / (1 + Math.exp(-1.0 * sum));
    }

    /**
     * Compute the output for a given input to the neural network.
     *
     * @param input The input provided to the neural network; the first
     *              {@code inputCount} elements are read.
     * @return The results from the output neurons, a new array of length
     *         {@code outputCount}.
     * @throws ArrayIndexOutOfBoundsException if {@code input} has fewer than
     *         {@code inputCount} elements.
     */
    public double[] computeOutputs(double input[]) {
        int i, j;
        final int hiddenIndex = inputCount;
        final int outIndex = inputCount + hiddenCount;

        for (i = 0; i < inputCount; i++) {
            fire[i] = input[i];
        }

        // first layer: input -> hidden
        int inx = 0;

        for (i = hiddenIndex; i < outIndex; i++) {
            double sum = thresholds[i];

            for (j = 0; j < inputCount; j++) {
                sum += fire[j] * matrix[inx++];
            }
            fire[i] = threshold(sum);
        }

        // hidden layer: hidden -> output. inx carries on from the first layer,
        // matching the offset (inputCount * hiddenCount) that calcError uses.
        double result[] = new double[outputCount];

        for (i = outIndex; i < neuronCount; i++) {
            double sum = thresholds[i];

            for (j = hiddenIndex; j < outIndex; j++) {
                sum += fire[j] * matrix[inx++];
            }
            fire[i] = threshold(sum);
            result[i - outIndex] = fire[i];
        }

        return result;
    }

    /**
     * Calculate the error for the recognition just done. Adds the squared
     * output error to the global error and accumulates the weight and
     * threshold changes that {@link #learn} will apply.
     *
     * @param ideal What the output neurons should have yielded; the first
     *              {@code outputCount} elements are read.
     */
    public void calcError(double ideal[]) {
        int i, j;
        final int hiddenIndex = inputCount;
        final int outputIndex = inputCount + hiddenCount;

        // clear hidden layer errors
        for (i = inputCount; i < neuronCount; i++) {
            error[i] = 0;
        }

        // layer errors and deltas for output layer
        for (i = outputIndex; i < neuronCount; i++) {
            error[i] = ideal[i - outputIndex] - fire[i];
            globalError += error[i] * error[i];
            errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
        }

        // hidden layer errors
        int winx = inputCount * hiddenCount;

        for (i = outputIndex; i < neuronCount; i++) {
            for (j = hiddenIndex; j < outputIndex; j++) {
                accMatrixDelta[winx] += errorDelta[i] * fire[j];
                error[j] += matrix[winx] * errorDelta[i];
                winx++;
            }
            accThresholdDelta[i] += errorDelta[i];
        }

        // hidden layer deltas
        for (i = hiddenIndex; i < outputIndex; i++) {
            errorDelta[i] = error[i] * fire[i] * (1 - fire[i]);
        }

        // input layer errors
        winx = 0; // offset into weight array
        for (i = hiddenIndex; i < outputIndex; i++) {
            for (j = 0; j < hiddenIndex; j++) {
                accMatrixDelta[winx] += errorDelta[i] * fire[j];
                // The input-neuron slots of error[] are written here but are not
                // cleared above and are not read anywhere else in this class.
                error[j] += matrix[winx] * errorDelta[i];
                winx++;
            }
            accThresholdDelta[i] += errorDelta[i];
        }
    }

    /**
     * Modify the weight matrix and thresholds based on the changes accumulated
     * by calls to {@link #calcError}, then clear the accumulators.
     */
    public void learn() {
        int i;

        // process the matrix
        for (i = 0; i < matrix.length; i++) {
            matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
            matrix[i] += matrixDelta[i];
            accMatrixDelta[i] = 0;
        }

        // process the thresholds
        for (i = inputCount; i < neuronCount; i++) {
            thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
            thresholds[i] += thresholdDelta[i];
            accThresholdDelta[i] = 0;
        }
    }

    /**
     * Reset the weight matrix and the thresholds to random values in the range
     * (-0.5, 0.5] and clear all pending and previous deltas.
     */
    public void reset() {
        int i;

        for (i = 0; i < neuronCount; i++) {
            thresholds[i] = 0.5 - (Math.random());
            thresholdDelta[i] = 0;
            accThresholdDelta[i] = 0;
        }
        for (i = 0; i < matrix.length; i++) {
            matrix[i] = 0.5 - (Math.random());
            matrixDelta[i] = 0;
            accMatrixDelta[i] = 0;
        }
    }
}
/** * XorExample * Copyright 2005 by Jeff Heaton(jeff@jeffheaton.com) * * Example program from Chapter 3 * Programming Neural Networks in Java * http://www.heatonresearch.com/articles/series/1/ * * This software is copyrighted. You may use it in programs * of your own, without restriction, but you may not * publish the source code without the author's permission. * For more information on distributing this code, please * visit: * http://www.heatonresearch.com/hr_legal.php * * @author Jeff Heaton * @version 1.1 */ public class XorExample extends JFrame implements ActionListener,Runnable {
/** * The number of input neurons. */ protected final static int NUM_INPUT = 2;
/** * The number of output neurons. */ protected final static int NUM_OUTPUT = 1;
/** * The number of hidden neurons. */ protected final static int NUM_HIDDEN = 3;
/** * The learning rate. */ protected final static double RATE = 0.5;
/** * The learning momentum. */ protected final static double MOMENTUM = 0.7;
/** * The training data that the user enters. * This represents the inputs and expected * outputs for the XOR problem. */ protected JTextField data[][] = new JTextField[4][4];
/** * The neural network. */ protected Network network;
/**
 * Constructor. Creates the network and lays out the window: a heading, the
 * training-data grid, the button row and the status line, one per row.
 */
public XorExample() {
    setTitle("XOR Solution");
    network = new Network(NUM_INPUT, NUM_HIDDEN, NUM_OUTPUT, RATE, MOMENTUM);

    Container content = getContentPane();

    GridBagLayout gridbag = new GridBagLayout();
    GridBagConstraints c = new GridBagConstraints();
    content.setLayout(gridbag);

    // Training input label
    c.gridwidth = GridBagConstraints.REMAINDER; // end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(new JLabel("Enter training data:"), c);

    // One header row plus four data rows.
    JPanel grid = new JPanel();
    grid.setLayout(new GridLayout(5, 4));
    grid.add(new JLabel("IN1"));
    grid.add(new JLabel("IN2"));
    grid.add(new JLabel("Expected OUT "));
    grid.add(new JLabel("Actual OUT"));

    for (int i = 0; i < 4; i++) {
        // The two low bits of the row index enumerate all four input pairs.
        int x = (i & 1);
        int y = (i & 2) >> 1;
        grid.add(data[i][0] = new JTextField("" + y));
        grid.add(data[i][1] = new JTextField("" + x));
        grid.add(data[i][2] = new JTextField("" + (x ^ y)));
        grid.add(data[i][3] = new JTextField("??"));
        // Only the expected-output column stays editable.
        data[i][0].setEditable(false);
        data[i][1].setEditable(false);
        data[i][3].setEditable(false);
    }

    content.add(grid, c);

    // the button panel
    JPanel buttonPanel = new JPanel(new FlowLayout());
    buttonPanel.add(btnTrain = new JButton("Train"));
    buttonPanel.add(btnRun = new JButton("Run"));
    buttonPanel.add(btnQuit = new JButton("Quit"));
    btnTrain.addActionListener(this);
    btnRun.addActionListener(this);
    btnQuit.addActionListener(this);

    // Add the button panel; without this the buttons are never displayed.
    c.gridwidth = GridBagConstraints.REMAINDER; // end row
    c.anchor = GridBagConstraints.CENTER;
    content.add(buttonPanel, c);

    // Status line
    c.gridwidth = GridBagConstraints.REMAINDER; // end row
    c.anchor = GridBagConstraints.NORTHWEST;
    content.add(status = new JLabel("Click train to begin training..."), c);

    // adjust size and position: centre the packed window on the screen
    pack();
    Toolkit toolkit = Toolkit.getDefaultToolkit();
    Dimension d = toolkit.getScreenSize();
    setLocation(
            (int) (d.width - this.getSize().getWidth()) / 2,
            (int) (d.height - this.getSize().getHeight()) / 2);
    setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
    setResizable(false);

    // Running is disabled until it is enabled elsewhere.
    btnRun.setEnabled(false);
}
/**
 * The main function, just display the JFrame.
 *
 * @param args No arguments are used.
 */
public static void main(String args[]) {
    // setVisible(true) replaces the deprecated show(boolean).
    (new XorExample()).setVisible(true);
}
/**
 * Dispatches a click on one of the three buttons: Quit exits the JVM,
 * Train starts training and Run evaluates the network.
 *
 * @param e The event.
 */
public void actionPerformed(ActionEvent e) {
    final Object clicked = e.getSource();

    if (clicked == btnQuit) {
        System.exit(0);
        return;
    }
    if (clicked == btnTrain) {
        train();
        return;
    }
    if (clicked == btnRun) {
        evaluate();
    }
}
/** * Called when the user clicks the run button. */ protected void evaluate() { double xorData[][] = getGrid(); int update=0;
/**
 * Called when the user clicks the train button. Starts a low-priority
 * background thread that executes {@code run()}. A click is ignored while
 * a previous training thread is still alive, so two threads never update
 * the same network at once.
 */
protected void train() {
    if (worker != null && worker.isAlive()) {
        return;
    }
    worker = new Thread(this);
    worker.setPriority(Thread.MIN_PRIORITY);
    worker.start();
}
/** * The thread worker, used for training */ public void run() { double xorData[][] = getGrid(); double xorIdeal[][] = getIdeal(); int update=0;
int max = 10000; for (int i=0;i<max;i++) { for (int j=0;j<xorData.length;j++) { network.computeOutputs(xorData[j]); network.calcError(xorIdeal[j]); network.learn(); }
/**
 * Called to generate an array of doubles based on
 * the training data that the user has entered.
 * Reads the first two columns (IN1, IN2) of every grid row.
 *
 * @return An array of doubles, one two-element row per grid row.
 * @throws NumberFormatException if a field does not contain a parsable number.
 */
double[][] getGrid() {
    double array[][] = new double[data.length][2];

    for (int i = 0; i < data.length; i++) {
        // Parse straight to double; going through float would round the value first.
        array[i][0] = Double.parseDouble(data[i][0].getText());
        array[i][1] = Double.parseDouble(data[i][1].getText());
    }

    return array;
}
/** * Called to get the ideal values that the neural network * should return for each of the grid training values. * * @return The ideal results. */ double [][]getIdeal() { double array[][] = new double[4][1];
for ( int i=0;i<4;i++ ) { array[i][0] = Float.parseFloat(data[i][2].getText()); }
[code="java"]/** * Network * Copyright 2005 by Jeff Heaton(jeff@jeffheaton.com) * * Example program from Chapter 3 * Programming Neural Networks in Java * http://www.heatonresearch.com/a...