学习来源:https://blog.csdn.net/minfanphd/article/details/116974889
一、通用BP神经网络 (1. 集中管理激活函数)
1、激活函数是神经网络的核心。
2、有很多的激活函数, 它们的设计有相应准则, 如分段可导。
3、查资料补充几个未实现的激活函数。
4、代码
package bp神经网络;
/**
 * Centralized manager of activation functions for BP neural networks. The
 * function is selected by a one-character code (see the constants below) and
 * applied through {@link #activate(double)} / {@link #derive(double, double)}.
 *
 * @time 2022/6/2
 * @author Liang Huang
 */
public class Activator {
	/**
	 * Arc tan.
	 */
	public static final char ARC_TAN = 'a';

	/**
	 * Elu.
	 */
	public static final char ELU = 'e';

	/**
	 * Gelu (not implemented yet).
	 */
	public static final char GELU = 'g';

	/**
	 * Hard logistic (not implemented yet).
	 */
	public static final char HARD_LOGISTIC = 'h';

	/**
	 * Identity.
	 */
	public static final char IDENTITY = 'i';

	/**
	 * Leaky relu, also known as parametric relu.
	 */
	public static final char LEAKY_RELU = 'l';

	/**
	 * Relu.
	 */
	public static final char RELU = 'r';

	/**
	 * Soft sign.
	 */
	public static final char SOFT_SIGN = 'o';

	/**
	 * Sigmoid.
	 */
	public static final char SIGMOID = 's';

	/**
	 * Tanh.
	 */
	public static final char TANH = 't';

	/**
	 * Soft plus.
	 */
	public static final char SOFT_PLUS = 'u';

	/**
	 * Swish (not implemented yet).
	 */
	public static final char SWISH = 'w';

	/**
	 * The currently selected activator code.
	 */
	private char activator;

	/**
	 * Alpha. Used by ELU and LEAKY_RELU (both branches read alpha).
	 */
	double alpha;

	/**
	 * Beta. Reserved for activators not implemented yet (e.g. swish);
	 * currently unused by activate()/derive().
	 */
	double beta;

	/**
	 * Gamma. Reserved for activators not implemented yet; currently unused
	 * by activate()/derive().
	 */
	double gamma;

	/**
	 *********************
	 * The first constructor.
	 *
	 * @param paraActivator The activator code, one of the constants above.
	 *********************
	 */
	public Activator(char paraActivator) {
		activator = paraActivator;
	}// Of the first constructor

	/**
	 *********************
	 * Setter.
	 *
	 * @param paraActivator The new activator code.
	 *********************
	 */
	public void setActivator(char paraActivator) {
		activator = paraActivator;
	}// Of setActivator

	/**
	 *********************
	 * Getter.
	 *
	 * @return The current activator code.
	 *********************
	 */
	public char getActivator() {
		return activator;
	}// Of getActivator

	/**
	 *********************
	 * Setter.
	 *
	 * @param paraAlpha The new alpha.
	 *********************
	 */
	void setAlpha(double paraAlpha) {
		alpha = paraAlpha;
	}// Of setAlpha

	/**
	 *********************
	 * Setter.
	 *
	 * @param paraBeta The new beta.
	 *********************
	 */
	void setBeta(double paraBeta) {
		beta = paraBeta;
	}// Of setBeta

	/**
	 *********************
	 * Setter.
	 *
	 * @param paraGamma The new gamma.
	 *********************
	 */
	void setGamma(double paraGamma) {
		gamma = paraGamma;
	}// Of setGamma

	/**
	 *********************
	 * Activate according to the activation function.
	 *
	 * @param paraValue The raw input x.
	 * @return f(x) for the selected function. Terminates the program for
	 *         unsupported codes.
	 *********************
	 */
	public double activate(double paraValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = Math.atan(paraValue);
			break;
		case ELU:
			// alpha * (e^x - 1) for negative inputs, identity otherwise.
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1);
			} // Of if
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = paraValue;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = alpha * paraValue;
			} // Of if
			break;
		case SOFT_SIGN:
			// x / (1 + |x|), written branch-wise to avoid Math.abs.
			if (paraValue >= 0) {
				resultValue = paraValue / (1 + paraValue);
			} else {
				resultValue = paraValue / (1 - paraValue);
			} // Of if
			break;
		case SOFT_PLUS:
			resultValue = Math.log(1 + Math.exp(paraValue));
			break;
		case RELU:
			if (paraValue >= 0) {
				resultValue = paraValue;
			} else {
				resultValue = 0;
			} // Of if
			break;
		case SIGMOID:
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case TANH:
			// tanh(x) rewritten as 2 * sigmoid(2x) - 1.
			resultValue = 2 / (1 + Math.exp(-2 * paraValue)) - 1;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			// Exit code 1: abnormal termination (0 would signal success).
			System.exit(1);
		}// Of switch

		return resultValue;
	}// Of activate

	/**
	 *********************
	 * Derive according to the activation function. Some use x while others use
	 * f(x).
	 *
	 * @param paraValue          The original value x.
	 * @param paraActivatedValue f(x), as returned by activate().
	 * @return The derivative f'(x). Terminates the program for unsupported
	 *         codes.
	 *********************
	 */
	public double derive(double paraValue, double paraActivatedValue) {
		double resultValue = 0;
		switch (activator) {
		case ARC_TAN:
			resultValue = 1 / (paraValue * paraValue + 1);
			break;
		case ELU:
			// For x < 0: d/dx [alpha * (e^x - 1)] = alpha * e^x = f(x) + alpha.
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha * (Math.exp(paraValue) - 1) + alpha;
			} // Of if
			break;
		// case GELU:
		// resultValue = ?;
		// break;
		// case HARD_LOGISTIC:
		// resultValue = ?;
		// break;
		case IDENTITY:
			resultValue = 1;
			break;
		case LEAKY_RELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = alpha;
			} // Of if
			break;
		case SOFT_SIGN:
			// 1 / (1 + |x|)^2, written branch-wise like activate().
			if (paraValue >= 0) {
				resultValue = 1 / (1 + paraValue) / (1 + paraValue);
			} else {
				resultValue = 1 / (1 - paraValue) / (1 - paraValue);
			} // Of if
			break;
		case SOFT_PLUS:
			// Derivative of soft plus is the sigmoid of x.
			resultValue = 1 / (1 + Math.exp(-paraValue));
			break;
		case RELU:
			if (paraValue >= 0) {
				resultValue = 1;
			} else {
				resultValue = 0;
			} // Of if
			break;
		case SIGMOID:
			// Uses f(x): sigmoid'(x) = f(x) * (1 - f(x)).
			resultValue = paraActivatedValue * (1 - paraActivatedValue);
			break;
		case TANH:
			// Uses f(x): tanh'(x) = 1 - f(x)^2.
			resultValue = 1 - paraActivatedValue * paraActivatedValue;
			break;
		// case SWISH:
		// resultValue = ?;
		// break;
		default:
			System.out.println("Unsupported activator: " + activator);
			// Exit code 1: abnormal termination (0 would signal success).
			System.exit(1);
		}// Of switch

		return resultValue;
	}// Of derive

	/**
	 *********************
	 * Overrides the method claimed in Object.
	 *********************
	 */
	public String toString() {
		String resultString = "Activator with function '" + activator + "'";
		resultString += "\r\n alpha = " + alpha + ", beta = " + beta + ", gamma = " + gamma;
		return resultString;
	}// Of toString

	/**
	 ********************
	 * Test the class.
	 ********************
	 */
	public static void main(String[] args) {
		Activator tempActivator = new Activator('s');
		double tempValue = 0.6;
		double tempNewValue;
		tempNewValue = tempActivator.activate(tempValue);
		System.out.println("After activation: " + tempNewValue);

		tempNewValue = tempActivator.derive(tempValue, tempNewValue);
		System.out.println("After derive: " + tempNewValue);
	}// Of main
}// Of class Activator
二、通用BP神经网络 (2. 单层实现)
1、仅实现单层 ANN。
2、可以有自己的激活函数。
3、正向计算输出, 反向计算误差并调整权值。
4、代码
package bp神经网络;
/**
* @time 2022/6/4
* @author Liang Huang
*/
import java.util.Arrays;
import java.util.Random;
public class AnnLayer {

/**
 * The number of input nodes (the bias node is stored as an extra weight row).
 */
int numInput;

/**
 * The number of output nodes.
 */
int numOutput;

/**
 * The learning rate.
 */
double learningRate;

/**
 * The momentum coefficient (mobp) for the weight updates.
 */
double mobp;

/**
 * The weight matrix. Row numInput (the extra row) holds the bias weights.
 */
double[][] weights;

/**
 * The previous weight changes, kept for the momentum term.
 */
double[][] deltaWeights;

/**
 * Errors on the input nodes, propagated back to the previous layer.
 */
double[] errors;

/**
 * The inputs.
 */
double[] input;

/**
 * The weighted sums, before activation.
 */
double[] output;

/**
 * The output after activate.
 */
double[] activatedOutput;

/**
 * The activation function manager of this layer.
 */
Activator activator;

/**
 * Random number generator for weight initialization.
 */
Random random = new Random();

/**
 *********************
 * The first constructor.
 *
 * @param paraNumInput     The number of input nodes.
 * @param paraNumOutput    The number of output nodes.
 * @param paraActivator    The activator code (see Activator).
 * @param paraLearningRate The learning rate.
 * @param paraMobp         The momentum coefficient.
 *********************
 */
public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
double paraLearningRate, double paraMobp) {
numInput = paraNumInput;
numOutput = paraNumOutput;
learningRate = paraLearningRate;
mobp = paraMobp;

// One extra row (index numInput) for the bias weights.
weights = new double[numInput + 1][numOutput];
deltaWeights = new double[numInput + 1][numOutput];
for (int i = 0; i < numInput + 1; i++) {
for (int j = 0; j < numOutput; j++) {
weights[i][j] = random.nextDouble();
} // Of for j
} // Of for i

errors = new double[numInput];

input = new double[numInput];
output = new double[numOutput];
activatedOutput = new double[numOutput];

activator = new Activator(paraActivator);
}// Of the first constructor

/**
 ********************
 * Set parameters for the activator.
 *
 * @param paraAlpha Alpha. Only valid for certain activator types.
 * @param paraBeta  Beta. Only valid for certain activator types.
 * @param paraGamma Gamma. Only valid for certain activator types.
 ********************
 */
public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
activator.setAlpha(paraAlpha);
activator.setBeta(paraBeta);
activator.setGamma(paraGamma);
}// Of setParameters

/**
 ********************
 * Forward prediction.
 *
 * @param paraInput The input data of one instance.
 * @return The data at the output end, after activation.
 ********************
 */
public double[] forward(double[] paraInput) {
//System.out.println("Ann layer forward " + Arrays.toString(paraInput));
// Copy data.
for (int i = 0; i < numInput; i++) {
input[i] = paraInput[i];
} // Of for i

// Calculate the weighted sum for each output.
for (int i = 0; i < numOutput; i++) {
// Start from the bias weight stored in the extra row.
output[i] = weights[numInput][i];
for (int j = 0; j < numInput; j++) {
output[i] += input[j] * weights[j][i];
} // Of for j

activatedOutput[i] = activator.activate(output[i]);
} // Of for i

return activatedOutput;
}// Of forward

/**
 ********************
 * Back propagation and change the edge weights.
 *
 * @param paraErrors The errors from the next layer, presumably of length
 *                   numOutput (TODO confirm against callers). Scaled in
 *                   place by the activation derivative.
 * @return The errors propagated to the previous layer (length numInput).
 ********************
 */
public double[] backPropagation(double[] paraErrors) {
//Step 1. Adjust the errors.
for (int i = 0; i < paraErrors.length; i++) {
paraErrors[i] = activator.derive(output[i], activatedOutput[i]) * paraErrors[i];
}//Of for i

//Step 2. Compute current errors.
// Note: errors[i] accumulates against weights[i][j] BEFORE the update of
// that entry in the same iteration; the statement order matters here.
for (int i = 0; i < numInput; i++) {
errors[i] = 0;
for (int j = 0; j < numOutput; j++) {
errors[i] += paraErrors[j] * weights[i][j];
deltaWeights[i][j] = mobp * deltaWeights[i][j]
+ learningRate * paraErrors[j] * input[i];
weights[i][j] += deltaWeights[i][j];
} // Of for j
} // Of for i

// The bias weights (extra row) are updated with an implicit input of 1.
for (int j = 0; j < numOutput; j++) {
deltaWeights[numInput][j] = mobp * deltaWeights[numInput][j] + learningRate * paraErrors[j];
weights[numInput][j] += deltaWeights[numInput][j];
} // Of for j

return errors;
}// Of backPropagation

/**
 ********************
 * I am the last layer, set the errors.
 *
 * @param paraTarget For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
 * @return The error of each output node: target minus activated output.
 ********************
 */
public double[] getLastLayerErrors(double[] paraTarget) {
double[] resultErrors = new double[numOutput];
for (int i = 0; i < numOutput; i++) {
resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
} // Of for i

return resultErrors;
}// Of getLastLayerErrors

/**
 ********************
 * Show me.
 ********************
 */
public String toString() {
String resultString = "";
resultString += "Activator: " + activator;
resultString += "\r\n weights = " + Arrays.deepToString(weights);

return resultString;
}// Of toString

/**
 ********************
 * Unit test. Feeds the forward output straight back as "errors", which is
 * only a smoke test of the data flow, not a training step.
 ********************
 */
public static void unitTest() {
AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
double[] tempInput = { 1, 4 };

System.out.println(tempLayer);

double[] tempOutput = tempLayer.forward(tempInput);
System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

double[] tempError = tempLayer.backPropagation(tempOutput);
System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
}// Of unitTest

/**
 ********************
 * Test the algorithm.
 ********************
 */
public static void main(String[] args) {
unitTest();
}// Of main
}// Of class AnnLayer
三、通用BP神经网络 (3. 综合测试)
1、自己尝试其它的激活函数
2、代码
package bp神经网络;
/**
* @time 2022/6/4
* @author Liang Huang
*/
import java.util.Arrays;
public class FullAnn extends GeneralAnn {

/**
 * The layers.
 */
AnnLayer[] layers;

/**
 ********************
 * The first constructor.
 *
 * @param paraFilename      The arff filename.
 * @param paraLayerNumNodes The number of nodes for each layer (may be different).
 * @param paraLearningRate  Learning rate.
 * @param paraMobp          Momentum coefficient.
 * @param paraActivators    The string storing the activator code of each
 *                          layer; charAt(i) selects layer i's activator.
 ********************
 */
public FullAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
double paraMobp, String paraActivators) {
super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

// Initialize layers. numLayers/layerNumNodes come from GeneralAnn
// (presumably set by the super constructor — TODO confirm).
layers = new AnnLayer[numLayers - 1];
for (int i = 0; i < layers.length; i++) {
layers[i] = new AnnLayer(layerNumNodes[i], layerNumNodes[i + 1], paraActivators.charAt(i), paraLearningRate,
paraMobp);
} // Of for i
}// Of the first constructor

/**
 ********************
 * Forward prediction: data passes through every layer in turn.
 *
 * @param paraInput The input data of one instance.
 * @return The data at the output end.
 ********************
 */
public double[] forward(double[] paraInput) {
double[] resultArray = paraInput;
for(int i = 0; i < numLayers - 1; i ++) {
resultArray = layers[i].forward(resultArray);
}//Of for i

return resultArray;
}// Of forward

/**
 ********************
 * Back propagation: errors start from the last layer and flow backwards,
 * updating each layer's weights on the way.
 *
 * @param paraTarget For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
 *
 ********************
 */
public void backPropagation(double[] paraTarget) {
double[] tempErrors = layers[numLayers - 2].getLastLayerErrors(paraTarget);
for (int i = numLayers - 2; i >= 0; i--) {
tempErrors = layers[i].backPropagation(tempErrors);
}//Of for i

return;
}// Of backPropagation

/**
 ********************
 * Show me.
 ********************
 */
public String toString() {
String resultString = "I am a full ANN with " + numLayers + " layers";
return resultString;
}// Of toString

/**
 ********************
 * Test the algorithm. train()/test() are presumably inherited from
 * GeneralAnn — TODO confirm.
 ********************
 */
public static void main(String[] args) {
int[] tempLayerNodes = { 4, 8, 8, 3 };
FullAnn tempNetwork = new FullAnn("D:/data/iris.arff", tempLayerNodes, 0.01,
0.6, "sss");

for (int round = 0; round < 5000; round++) {
tempNetwork.train();
} // Of for n

double tempAccuray = tempNetwork.test();
System.out.println("The accuracy is: " + tempAccuray);
System.out.println("FullAnn ends.");
}// Of main
}// Of class FullAnn