实现了一个神经网络分类器。最初把学习率(步长)设为迭代轮数的倒数,效果不好;后来改为固定值 0.2,效果较好。用一个以抛物线为分类边界的例子做测试,准确率在 96% 以上。
/**
 * A fully connected feed-forward network of sigmoid units, trained one sample at a time with
 * back-propagation.
 *
 * <p>Layer 0 is the input layer and only forwards the input vector. Every later layer applies the
 * logistic function to the weighted sum of the previous layer's outputs plus a per-neurode bias
 * ({@code theta}). Weights and biases are initialised with {@link Math#random()}, so two instances
 * built from the same arguments generally differ.
 *
 * <p>Both {@link #predict} and {@link #train} overwrite per-neurode scratch state, so an instance
 * must not be used from several threads at once.
 */
public final class NeuroNetwork {

    /** Learning rate applied to every weight and bias update. */
    private static final double LEARNING_RATE = 0.2;

    /** Per-neurode state: back-propagated error, last activation and bias. */
    private static class Neurode {
        double err;
        double output;
        double theta;
    }

    /** Life-cycle of a network: it can be trained exactly once. */
    private enum Status {
        NEW,
        TRAINED
    }

    // status of this instance, either NEW or TRAINED
    private Status status;
    // depth of network, layer 0 is input layer
    private final int depth;
    // neurodes in each layer
    private final Neurode[][] neurodes;
    // weights[d][i][j] is the weight from neurode i of layer d to neurode j of layer d+1
    private final double[][][] weights;

    /**
     * Creates a network with randomly initialised weights and biases.
     *
     * @param depth the number of layers, including the input layer; must be at least 1
     * @param numNeurodes the number of neurodes in each layer; the first {@code depth} entries are
     *     used and each of them must be at least 1
     * @throws IllegalArgumentException if {@code depth} is less than 1, {@code numNeurodes} is
     *     {@code null} or shorter than {@code depth}, or a used layer size is less than 1
     */
    public NeuroNetwork(int depth, int[] numNeurodes) {
        if (depth < 1) {
            throw new IllegalArgumentException("depth must be at least 1: " + depth);
        }
        if (numNeurodes == null || numNeurodes.length < depth) {
            throw new IllegalArgumentException("numNeurodes must describe all " + depth + " layers");
        }
        for (int d = 0; d < depth; d++) {
            if (numNeurodes[d] < 1) {
                throw new IllegalArgumentException(
                        "layer " + d + " must have at least one neurode: " + numNeurodes[d]);
            }
        }
        this.depth = depth;

        // create neurodes and give each one a random bias
        neurodes = new Neurode[depth][];
        for (int d = 0; d < depth; d++) {
            neurodes[d] = new Neurode[numNeurodes[d]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                neurodes[d][i] = new Neurode();
                neurodes[d][i].theta = Math.random();
            }
        }

        // one weight matrix per pair of adjacent layers, hence depth-1 matrices
        weights = new double[depth - 1][][];
        for (int d = 0; d < depth - 1; d++) {
            weights[d] = new double[numNeurodes[d]][numNeurodes[d + 1]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                for (int j = 0; j < numNeurodes[d + 1]; j++) {
                    weights[d][i][j] = Math.random();
                }
            }
        }
        status = Status.NEW;
    }

    /**
     * Runs a forward pass and stores every neurode's activation in its {@code output} field.
     *
     * @param data the input vector; its first {@code neurodes[0].length} entries are read
     */
    private void calculateOutput(double[] data) {
        // the input layer forwards the raw input
        for (int i = 0; i < neurodes[0].length; i++) {
            neurodes[0][i].output = data[i];
        }
        // every later layer: sigmoid(weighted sum of previous layer + bias)
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double input = neurodes[d][j].theta;
                for (int i = 0; i < neurodes[d - 1].length; i++) {
                    input += neurodes[d - 1][i].output * weights[d - 1][i][j];
                }
                neurodes[d][j].output = 1.0 / (1.0 + Math.exp(-input));
            }
        }
    }

    /**
     * Classifies one input vector.
     *
     * @param data the input vector; its length must equal the size of the input layer
     * @param output buffer that receives the activation of every output neurode; its length must
     *     equal the size of the last layer
     * @return the index of the output neurode with the largest activation (the lowest index wins
     *     on ties)
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public int predict(double[] data, double[] output) {
        Neurode[] last = neurodes[depth - 1];
        if (data.length != neurodes[0].length || output.length != last.length) {
            throw new IllegalArgumentException(
                    "expected " + neurodes[0].length + " inputs and " + last.length + " outputs");
        }
        calculateOutput(data);
        int label = 0;
        for (int i = 0; i < last.length; i++) {
            output[i] = last[i].output;
            if (output[i] > output[label]) {
                label = i;
            }
        }
        return label;
    }

    /**
     * Trains the network with repeated passes over the samples until the largest weight gradient
     * seen in a pass drops below {@code threshold} or {@code maxIteration} passes have run.
     * Progress is printed to standard output after every pass.
     *
     * @param data training inputs; {@code data[i]} is the i-th sample
     * @param target training labels; {@code target[i]} is the desired output vector for
     *     {@code data[i]}
     * @param maxIteration maximum number of passes over the samples
     * @param threshold a pass counts as converged when its largest absolute weight gradient
     *     (before scaling by the learning rate) is below this value
     * @param errorRate currently unused; accepted for interface compatibility
     * @return {@code true} if training converged, {@code false} if it stopped because
     *     {@code maxIteration} was exhausted
     * @throws IllegalStateException if this network has already been trained
     * @throws IllegalArgumentException if there are no samples, the number of samples and labels
     *     differ, or any row has the wrong length
     */
    public boolean train(double[][] data, double[][] target, int maxIteration, double threshold, double errorRate) {
        if (status == Status.TRAINED) {
            throw new IllegalStateException("network has already been trained");
        }
        if (data.length == 0 || target.length != data.length) {
            throw new IllegalArgumentException(
                    "need the same positive number of samples and labels: "
                            + data.length + " vs " + target.length);
        }
        // validate every row up front so a bad sample cannot leave the weights half-updated
        int inputs = neurodes[0].length;
        int outputs = neurodes[depth - 1].length;
        for (int r = 0; r < data.length; r++) {
            if (data[r].length != inputs || target[r].length != outputs) {
                throw new IllegalArgumentException("sample " + r + " has the wrong dimensions");
            }
        }

        boolean convergence = false;
        for (int round = 1; round <= maxIteration && !convergence; round++) {
            double delta = 0.0;
            for (int r = 0; r < data.length; r++) {
                double res = trainWithOneSample(data[r], target[r], LEARNING_RATE);
                if (delta < res) {
                    delta = res;
                }
            }
            convergence = delta < threshold;
            System.out.printf(" %d round of train, delta is %f %n", round, delta);
        }
        status = Status.TRAINED;
        return convergence;
    }

    /**
     * Performs one back-propagation step for a single sample.
     *
     * @param data the input vector of the sample
     * @param target the desired output vector of the sample
     * @param rate learning rate
     * @return the largest absolute weight gradient of this step, before scaling by {@code rate}
     */
    private double trainWithOneSample(double[] data, double[] target, double rate) {
        calculateOutput(data);

        // error of the output layer: sigmoid derivative times (target - output)
        Neurode[] last = neurodes[depth - 1];
        for (int j = 0; j < last.length; j++) {
            double out = last[j].output;
            last[j].err = out * (1 - out) * (target[j] - out);
        }

        // propagate the error back through hidden layers n-2 ... 1, using the not-yet-updated weights
        for (int d = depth - 2; d > 0; d--) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double error = 0.0;
                for (int k = 0; k < neurodes[d + 1].length; k++) {
                    error += neurodes[d + 1][k].err * weights[d][j][k];
                }
                double out = neurodes[d][j].output;
                neurodes[d][j].err = out * (1 - out) * error;
            }
        }

        // update weights and track the largest gradient
        double maxDelta = 0.0;
        for (int d = 0; d < depth - 1; d++) {
            for (int i = 0; i < neurodes[d].length; i++) {
                for (int j = 0; j < neurodes[d + 1].length; j++) {
                    double delta = neurodes[d][i].output * neurodes[d + 1][j].err;
                    weights[d][i][j] += rate * delta;
                    if (maxDelta < Math.abs(delta)) {
                        maxDelta = Math.abs(delta);
                    }
                }
            }
        }

        // update biases
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                neurodes[d][j].theta += rate * neurodes[d][j].err;
            }
        }
        return maxDelta;
    }
}
测试:
/**
 * Demo driver: trains a {@code NeuroNetwork} on a coarse grid over the unit square and measures
 * its accuracy on a finer grid. Points are split into three classes by the parabola
 * {@code y = 4(x-0.5)^2} and the vertical line {@code x = 0.5}.
 */
public class TestMain {

    /** Number of classes produced by {@link #calculateLabel}. */
    private static final int NUM_CLASSES = 3;

    /**
     * Builds an {@code m x m} grid of points over the unit square together with one-hot labels.
     *
     * @param m number of grid points per axis; must not be negative
     * @return a two-element array: {@code [0]} holds the {@code m*m} points as {@code {x, y}}
     *     pairs and {@code [1]} holds the matching one-hot label vectors of length 3
     * @throws IllegalArgumentException if {@code m} is negative
     */
    public static double[][][] generateData(int m) {
        if (m < 0) {
            throw new IllegalArgumentException("m must not be negative: " + m);
        }
        double[][] data = new double[m * m][2];
        // new arrays are zero-filled, so only the winning class needs to be set
        double[][] label = new double[m * m][NUM_CLASSES];
        for (int i = 0; i < m; i++) {
            double x = coordinate(i, m);
            for (int j = 0; j < m; j++) {
                double y = coordinate(j, m);
                int row = i * m + j;
                data[row][0] = x;
                data[row][1] = y;
                label[row][calculateLabel(x, y)] = 1;
            }
        }
        return new double[][][] {data, label};
    }

    /**
     * Maps a grid index to a coordinate in {@code [0, 1]}. A single-point grid has no spacing, so
     * its only point is placed at 0 instead of evaluating {@code 0 / 0.0}, which would be NaN.
     */
    private static double coordinate(int index, int m) {
        return (m > 1) ? index / (m - 1.0) : 0.0;
    }

    /**
     * Returns the class of a point.
     *
     * @param x horizontal coordinate
     * @param y vertical coordinate
     * @return 0 if the point lies strictly above the parabola {@code y = 4(x-0.5)^2}; otherwise 1
     *     if {@code x < 0.5} and 2 if not
     */
    public static int calculateLabel(double x, double y) {
        if (y > 4.0 * (x - 0.5) * (x - 0.5)) {
            return 0;
        } else if (x < 0.5) {
            return 1;
        } else {
            return 2;
        }
    }

    /**
     * Trains on a 10x10 grid, evaluates on a 50x50 grid and prints every prediction followed by
     * the overall accuracy.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[] num = {2, 3, NUM_CLASSES};
        int m = 10;
        NeuroNetwork inst = new NeuroNetwork(num.length, num);
        double[][][] trainData = generateData(m);
        inst.train(trainData[0], trainData[1], 1000000, 0.001, 0.8);

        int t = 50, success = 0;
        double[][] testPoints = generateData(t)[0];
        // scratch buffer for the raw activations, so predict() does not write into label data
        double[] scores = new double[NUM_CLASSES];
        for (int i = 0; i < t * t; i++) {
            int res = inst.predict(testPoints[i], scores);
            int ans = calculateLabel(testPoints[i][0], testPoints[i][1]);
            if (res == ans) {
                success++;
            }
            System.out.printf("<%f, %f> : %d %b%n", testPoints[i][0], testPoints[i][1], res, res == ans);
        }
        System.out.printf("Accuracy rate is %f%n", (success + 0.0) / (t * t));
    }
}