day75

本文介绍了一个名为`AnnLayer`的类,它实现了人工神经网络(ANN)的一层。该类包含输入和输出数量、学习率、动量等参数,并支持前向传播和反向传播计算。`AnnLayer`还允许设置激活函数参数,并提供了调整权重和偏置的方法。通过单元测试展示了如何使用这个类进行简单的前向预测和误差反向传播。
摘要由CSDN通过智能技术生成
package machinelearning.ann;

import java.util.Arrays;
import java.util.Random;
/**
 * ******************************************
 * One fully connected layer of an artificial neural network (ANN).
 * <p>
 * The layer stores an {@code (numInput + 1) x numOutput} weight matrix; the
 * extra last row ({@code weights[numInput]}) holds the bias of every output
 * node. Training is gradient descent with a momentum term ({@code mobp}).
 *
 * @author Michelle Min MitchelleMin@163.com
 * @date 2021-08-07
 * ******************************************
 */
public class AnnLayer {

    /**
     * The number of inputs.
     */
    int numInput;

    /**
     * The number of outputs.
     */
    int numOutput;

    /**
     * The learning rate.
     */
    double learningRate;

    /**
     * The momentum coefficient: the fraction of the previous update that is
     * carried into the next update.
     */
    double mobp;

    /**
     * The weight matrix and the last applied weight update (for momentum).
     * Row {@code numInput} of {@code weights} is the bias row.
     */
    double[][] weights, deltaWeights;

    /**
     * offset: a mirror of the bias row {@code weights[numInput]}.
     * deltaOffset: the last applied bias update (for momentum).
     * errors: the errors propagated back to the inputs by the last
     * {@link #backPropagation(double[])} call.
     */
    double[] offset, deltaOffset, errors;

    /**
     * The inputs of the last forward pass.
     */
    double[] input;

    /**
     * The weighted sums (before activation) of the last forward pass.
     */
    double[] output;

    /**
     * The outputs after activation of the last forward pass.
     */
    double[] activatedOutput;

    /**
     * The activation function of this layer.
     */
    Activator activator;

    /**
     * The random number generator used to initialize the weights.
     */
    Random random = new Random();

    /**
     *********************
     * The first constructor.
     *
     * @param paraNumInput
     *            The number of inputs.
     * @param paraNumOutput
     *            The number of outputs.
     * @param paraActivator
     *            The activator.
     * @param paraLearningRate
     *            The learning rate.
     * @param paraMobp
     *            The momentum coefficient.
     *********************
     */
    public AnnLayer(int paraNumInput, int paraNumOutput, char paraActivator,
                    double paraLearningRate, double paraMobp) {
        numInput = paraNumInput;
        numOutput = paraNumOutput;
        learningRate = paraLearningRate;
        mobp = paraMobp;

        // One extra row for the biases.
        weights = new double[numInput + 1][numOutput];
        deltaWeights = new double[numInput + 1][numOutput];
        for (int i = 0; i < numInput + 1; i++) {
            for (int j = 0; j < numOutput; j++) {
                weights[i][j] = random.nextDouble();
            } //of for j
        } //of for i

        offset = new double[numOutput];
        deltaOffset = new double[numOutput];
        errors = new double[numInput];
        // Keep the offset mirror in sync with the randomly initialized bias row.
        System.arraycopy(weights[numInput], 0, offset, 0, numOutput);

        input = new double[numInput];
        output = new double[numOutput];
        activatedOutput = new double[numOutput];

        activator = new Activator(paraActivator);
    }//of the first constructor

    /**
     ********************
     * Set parameters for the activator.
     *
     * @param paraAlpha
     *            Alpha. Only valid for certain types.
     * @param paraBeta
     *            Beta.
     * @param paraGamma
     *            Gamma.
     ********************
     */
    public void setParameters(double paraAlpha, double paraBeta, double paraGamma) {
        activator.setAlpha(paraAlpha);
        activator.setBeta(paraBeta);
        activator.setGamma(paraGamma);
    }//of setParameters

    /**
     ********************
     * Forward prediction.
     *
     * @param paraInput
     *            The input data of one instance. Only the first
     *            {@code numInput} values are used.
     * @return A copy of the data at the output end (after activation).
     * @throws IllegalArgumentException
     *             If fewer than {@code numInput} values are given.
     ********************
     */
    public double[] forward(double[] paraInput) {
        if (paraInput.length < numInput) {
            throw new IllegalArgumentException("Expected at least " + numInput
                    + " inputs, got " + paraInput.length);
        }//of if

        // Copy data so later back propagation sees the inputs of this pass.
        System.arraycopy(paraInput, 0, input, 0, numInput);

        // Calculate the weighted sum for each output.
        for (int i = 0; i < numOutput; i++) {
            // Start from the bias (last weight row), then add weighted inputs.
            output[i] = weights[numInput][i];
            for (int j = 0; j < numInput; j++) {
                output[i] += input[j] * weights[j][i];
            }//of for j

            activatedOutput[i] = activator.activate(output[i]);
        }//of for i

        // Return a copy so callers cannot corrupt the layer's state.
        return Arrays.copyOf(activatedOutput, numOutput);
    }//of forward

    /**
     ********************
     * Back propagation and change the edge weights (including the biases).
     * The caller's array is not modified.
     *
     * @param paraErrors
     *            The errors at the output end of this layer, one per output.
     * @return A copy of the errors propagated to the inputs of this layer.
     * @throws IllegalArgumentException
     *             If the number of errors differs from {@code numOutput}.
     ********************
     */
    public double[] backPropagation(double[] paraErrors) {
        if (paraErrors.length != numOutput) {
            throw new IllegalArgumentException("Expected " + numOutput
                    + " errors, got " + paraErrors.length);
        }//of if

        //Step 1. Adjust the errors by the activation derivative (local gradients).
        double[] tempGradients = new double[numOutput];
        for (int j = 0; j < numOutput; j++) {
            tempGradients[j] = activator.derive(output[j], activatedOutput[j]) * paraErrors[j];
        }//of for j

        //Step 2. Compute current errors (using the weights before this update)
        // and adjust the input weights.
        for (int i = 0; i < numInput; i++) {
            errors[i] = 0;
            for (int j = 0; j < numOutput; j++) {
                errors[i] += tempGradients[j] * weights[i][j];
                deltaWeights[i][j] = mobp * deltaWeights[i][j]
                        + learningRate * tempGradients[j] * input[i];
                weights[i][j] += deltaWeights[i][j];
            }//of for j
        }//of for i

        //Step 3. Adjust the biases. They live in the last weight row, which is
        // what forward() reads; offset mirrors that row.
        for (int j = 0; j < numOutput; j++) {
            deltaOffset[j] = mobp * deltaOffset[j] + learningRate * tempGradients[j];
            weights[numInput][j] += deltaOffset[j];
            offset[j] = weights[numInput][j];
        }//of for j

        return Arrays.copyOf(errors, numInput);
    }//of backPropagation

    /**
     ********************
     * I am the last layer, set the errors.
     *
     * @param paraTarget
     *            For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
     * @return The difference target - activated output for each output node.
     ********************
     */
    public double[] getLastLayerErrors(double[] paraTarget) {
        double[] resultErrors = new double[numOutput];
        for (int i = 0; i < numOutput; i++) {
            resultErrors[i] = (paraTarget[i] - activatedOutput[i]);
        }//of for i

        return resultErrors;
    }//of getLastLayerErrors

    /**
     ********************
     * Show me.
     ********************
     */
    @Override
    public String toString() {
        String resultString = "";
        resultString += "Activator: " + activator;
        resultString += "\r\n weights = " + Arrays.deepToString(weights);
        return resultString;
    }//of toString

    /**
     ********************
     * Unit test: one forward pass followed by one back propagation step
     * against a one-hot target.
     ********************
     */
    public static void unitTest() {
        AnnLayer tempLayer = new AnnLayer(2, 3, 's', 0.01, 0.1);
        double[] tempInput = { 1, 4 };

        System.out.println(tempLayer);

        double[] tempOutput = tempLayer.forward(tempInput);
        System.out.println("Forward, the output is: " + Arrays.toString(tempOutput));

        double[] tempTarget = { 0, 1, 0 };
        double[] tempLastErrors = tempLayer.getLastLayerErrors(tempTarget);
        double[] tempError = tempLayer.backPropagation(tempLastErrors);
        System.out.println("Back propagation, the error is: " + Arrays.toString(tempError));
    }//of unitTest

    /**
     ********************
     * Test the algorithm.
     ********************
     */
    public static void main(String[] args) {
        unitTest();
    }//of main
}//of class AnnLayer

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值