神经网络,多输入多输出
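反向传播使用链式求导法则计算各权重的梯度。记 $o_j$ 为神经元 $j$ 的 sigmoid 输出，$t_j$ 为期望值，$x_i$ 为该神经元的第 $i$ 个输入，$\eta$ 为学习率（代码中的 trainrate），误差取 $E = \tfrac{1}{2}\sum_j (o_j - t_j)^2$，公式如下:
输出层: $\delta_j = \dfrac{\partial E}{\partial z_j} = (o_j - t_j)\, o_j (1 - o_j)$
隐藏层: $\delta_j = \Big(\sum_k \delta_k\, w_{kj}\Big)\, o_j (1 - o_j)$，其中 $k$ 遍历下一层的神经元
权重梯度: $\dfrac{\partial E}{\partial w_{ji}} = \delta_j\, x_i$，偏置梯度: $\dfrac{\partial E}{\partial b_j} = \delta_j$
权重更新: $w \leftarrow w - \eta\, \dfrac{\partial E}{\partial w}$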
单个神经元类
/**
 * A single sigmoid neuron computing {@code sigmoid(x1*w1 + ... + xn*wn + 1*b)}.
 *
 * <p>{@code weight} holds one entry per input followed by a trailing bias weight (the bias input
 * is the constant 1), so {@code weight.length == input.length + 1}. Two training paths exist:
 * finite-difference gradient descent on this neuron's own absolute error ({@link #train()}), and
 * gradients written by an external back-propagation pass through {@link #setEz(double)} and
 * {@link #setEw(double[])}.
 *
 * @author SXC 2020年8月13日 下午9:48:19
 */
public class Neurons {
    /** Model whose output {@link #merrate()} measures; null until {@link #settN} is called. */
    ThreeNeurons tN;
    /** Current inputs. The array passed to {@link #setInput} is stored by reference, not copied. */
    double[] input;
    /** Not read anywhere in this class; kept for compatibility with other code. */
    double[] inputandoutput;
    /** Input weights followed by the bias weight; one entry longer than {@link #input}. */
    double[] weight;
    /** Output cached by the most recent {@link #calaOutput()} call. */
    private double nowoutput;
    /** dE/dz: derivative of the error with respect to the pre-activation sum. */
    private double ez;
    /** dE/dw for every weight (bias last); kept the same length as {@link #weight}. */
    private double[] ew;
    /** Target output used by {@link #errate()} and {@link #merrate()}. */
    double expectation = 0;
    /** Learning rate applied by {@link #updateweight()}. */
    double trainrate = 10;
    /** Finite-difference step used to estimate gradients numerically. */
    double step = 0.00001;
    /** Pending weight gradient produced by setDw/msetDw and consumed by updateweight. */
    double[] dw;

    /**
     * Creates a neuron with random weights in [-1, 1).
     *
     * @param inputcount number of input signals (the bias weight is added automatically)
     */
    Neurons(int inputcount) {
        input = new double[inputcount];
        weight = new double[inputcount + 1];
        initweight();
    }

    public double[] getInput() {
        return input;
    }

    public double[] getEw() {
        return ew;
    }

    public void setEw(double[] ew) {
        this.ew = ew;
    }

    public void setEz(double ez) {
        this.ez = ez;
    }

    public double getEz() {
        return ez;
    }

    public void settN(ThreeNeurons tN) {
        this.tN = tN;
    }

    /** Returns the output cached by the last {@link #calaOutput()} call (no recomputation). */
    public double getNowoutput() {
        return nowoutput;
    }

    /** Sets the target output value. */
    public void setExpectation(double expectation) {
        this.expectation = expectation;
    }

    /** Returns |output - expectation| for this neuron alone (recomputes the output). */
    public double errate() {
        return Math.abs(calaOutput() - this.expectation);
    }

    /** Returns |model output - expectation|, using the model set via {@link #settN}. */
    public double merrate() {
        return Math.abs(tN.calaOutput() - this.expectation);
    }

    /**
     * Replaces the weight array (stored by reference) and makes sure the gradient buffer
     * {@code ew} has a matching length.
     */
    public void setWeight(double[] weight) {
        this.weight = weight;
        ensureGradientBuffer();
    }

    /** Fills the weights with uniform random values in [-1, 1). */
    private void initweight() {
        for (int i = 0; i < weight.length; i++) {
            weight[i] = Math.random() * 2 - 1;
        }
        ensureGradientBuffer();
    }

    /** Allocates {@code ew} when missing or when it was sized for a different weight array. */
    private void ensureGradientBuffer() {
        if (ew == null || ew.length != weight.length) {
            ew = new double[weight.length];
        }
    }

    /** Returns the weight array itself (bias weight last); callers may mutate it in place. */
    public double[] getWeight() {
        return weight;
    }

    // --------------------输出打印---------------------------------------------

    /** Prints the current dE/dw values. */
    public void getEwtostring() {
        // The count printed here is the number of ew entries, not the number of inputs.
        System.out.println("ew数:" + ew.length + joinValues("当前ew为:[ ", ew) + "] ");
    }

    /** Prints the current dE/dz value. */
    public void getEztostring() {
        System.out.println("当前ez为:[ " + ez + " ]");
    }

    /** Prints the current inputs. */
    public void getinputtostring() {
        System.out.println("输入数:" + input.length + joinValues("当前输入为:[ ", input) + "] ");
    }

    /** Recomputes and prints this neuron's output. */
    public void getoutnputtostring() {
        System.out.println("该神经元输出:" + calaOutput());
    }

    /** Prints the current weights (bias weight last). */
    public void getweighttostring() {
        System.out.println("权重数:" + weight.length + joinValues("当前权重为:[ ", weight) + "] ");
    }

    /** Returns {@code prefix} followed by every value, each followed by one space. */
    private static String joinValues(String prefix, double[] values) {
        StringBuilder sb = new StringBuilder(prefix);
        for (double value : values) {
            sb.append(value).append(' ');
        }
        return sb.toString();
    }

    // --------------------输出打印---------------------------------------------

    /** Replaces the input array (stored by reference, not copied). */
    public void setInput(double[] input) {
        this.input = input;
    }

    /**
     * Copies all but the last element into the existing input array and uses the last element
     * as the expectation.
     *
     * @param inputandoutput {@code {x1, ..., xn, expected}}; must not have more than
     *     {@code input.length} leading input values
     */
    public void setInputandOutput(double[] inputandoutput) {
        int inputCount = inputandoutput.length - 1;
        System.arraycopy(inputandoutput, 0, this.input, 0, inputCount);
        setExpectation(inputandoutput[inputCount]);
    }

    /**
     * Computes {@code sigmoid(x1w1 + ... + xnwn + bias)}, caches it as the current output and
     * returns it.
     */
    public double calaOutput() {
        nowoutput = sigmoid(total());
        return nowoutput;
    }

    /** Returns the pre-activation sum {@code x1w1 + ... + xnwn + 1*bias}. */
    public double total() {
        double sum = 0;
        for (int i = 0; i < input.length; i++) {
            sum = input[i] * weight[i] + sum;
        }
        return sum + 1 * weight[weight.length - 1];
    }

    // ----------------------调整权重---------------------------------------

    /** Performs one finite-difference gradient step on this neuron's own error. */
    public void train() {
        setDw();
        updateweight();
    }

    /** Estimates d|output - expectation|/dw numerically into {@link #dw}. */
    public void setDw() {
        dw = numericGradient(false);
    }

    /** Estimates d|model output - expectation|/dw numerically into {@link #dw}. */
    public void msetDw() {
        dw = numericGradient(true);
    }

    /**
     * Forward-difference gradient of the chosen error with respect to every weight.
     *
     * @param modelError true to measure {@link #merrate()}, false for {@link #errate()}
     */
    private double[] numericGradient(boolean modelError) {
        double[] grad = new double[weight.length];
        double baseline = modelError ? merrate() : errate();
        for (int i = 0; i < weight.length; i++) {
            double saved = weight[i];
            weight[i] = saved + step;
            double moved = modelError ? merrate() : errate();
            grad[i] = (moved - baseline) / step;
            // Restore exactly; "+= step; -= step" can leave floating-point drift behind.
            weight[i] = saved;
        }
        return grad;
    }

    /**
     * Applies {@code weight[i] -= dw[i] * trainrate}.
     *
     * @throws IllegalStateException if no gradient has been computed yet
     */
    public void updateweight() {
        if (dw == null) {
            throw new IllegalStateException("setDw() or msetDw() must be called before updateweight()");
        }
        for (int i = 0; i < dw.length; i++) {
            weight[i] -= dw[i] * trainrate;
        }
    }

    public double getTrainrate() {
        return trainrate;
    }

    /** Performs {@code time} gradient steps. */
    public void train(int time) {
        for (int i = 0; i < time; i++) {
            train();
        }
    }

    /**
     * Trains until the error is at most {@code d}. There is no iteration cap, so this does not
     * return if the target error is unreachable.
     */
    public void train(double d) {
        while (errate() > d) {
            train();
        }
    }

    // ----------------------调整权重---------------------------------------

    /** Logistic activation {@code 1 / (1 + e^-x)}. */
    public double sigmoid(double input) {
        return 1.0 / (1.0 + Math.exp(-input));
    }
}
神经网络类（NetWork，由多层神经元组成）
/**
 * A fully connected feed-forward network of sigmoid {@link Neurons}, trained by
 * back-propagation.
 *
 * <p>Layer {@code i} holds {@code m[i]} neurons; the last layer is the output layer. Neurons are
 * stored layer by layer in {@link #AllNeurons} and addressed with 1-based (column = layer,
 * row = position within the layer) coordinates via {@link #getNeurons(int, int)}.
 *
 * @author SXC 2020年8月14日 下午9:21:20
 */
public class NetWork {
    /** Number of neurons in each layer (private copy of the constructor argument). */
    private int[] NeuronsC;
    /** Network outputs from the most recent forward pass. */
    private double[] nowoutput;
    /** All neurons, layer after layer. */
    Neurons[] AllNeurons;
    /** Current network input; allocated with {@code in} entries by the constructor. */
    double[] input;
    /** Target outputs for the current sample. */
    double[] expectation;
    /** Learning rate used by {@link #update_weights()}. */
    double trainrate = 1;

    /**
     * Builds the network with random initial weights.
     *
     * @param m neurons per layer; the last entry is the output layer and must equal {@code out}
     * @param in number of network inputs
     * @param out number of network outputs
     * @throws IllegalArgumentException if {@code m} is empty, contains a non-positive size, or its
     *     last entry differs from {@code out}
     */
    NetWork(int[] m, int in, int out) {
        if (m == null || m.length == 0) {
            throw new IllegalArgumentException("layer sizes must not be empty");
        }
        if (m[m.length - 1] != out) {
            // Previously only printed; continuing led to NPE/AIOOBE in forward_pass/update_weights.
            throw new IllegalArgumentException("数据输入存在问题!! last layer has "
                    + m[m.length - 1] + " neurons but out=" + out);
        }
        input = new double[in];
        NeuronsC = m.clone();
        int n = 0;
        for (int size : NeuronsC) {
            if (size <= 0) {
                throw new IllegalArgumentException("layer size must be positive: " + size);
            }
            n += size;
        }
        AllNeurons = new Neurons[n];
        int index = 0;
        for (int layer = 0; layer < NeuronsC.length; layer++) {
            // The first layer reads the external inputs; later layers read the previous layer.
            int fanIn = layer == 0 ? in : NeuronsC[layer - 1];
            for (int j = 0; j < NeuronsC[layer]; j++) {
                AllNeurons[index++] = new Neurons(fanIn);
            }
        }
        nowoutput = new double[out];
        expectation = new double[out];
        System.out.println("生成" + n + "个神经元");
    }

    /** Replaces the input array (stored by reference). */
    public void setInput(double[] input) {
        this.input = input;
    }

    /**
     * Copies one sample into the current input and expectation.
     *
     * @param inputandlable {@code input.length} input values followed by
     *     {@code expectation.length} target values
     * @throws IllegalArgumentException if the sample is too short
     */
    public void setInputandLable(double[] inputandlable) {
        if (inputandlable.length < input.length + expectation.length) {
            throw new IllegalArgumentException("sample needs " + (input.length + expectation.length)
                    + " values but has " + inputandlable.length);
        }
        System.arraycopy(inputandlable, 0, input, 0, input.length);
        System.arraycopy(inputandlable, input.length, expectation, 0, expectation.length);
    }

    /** Sets the target of the first output only. */
    public void setExpectation(double expectation) {
        this.expectation[0] = expectation;
    }

    /**
     * Runs one forward pass and returns the summed absolute error
     * {@code sum |expectation[i] - output[i]|} over all outputs.
     *
     * <p>Absolute values are summed so that errors of opposite sign on different outputs cannot
     * cancel each other out.
     */
    public double err() {
        forward_pass(0);
        double total = 0;
        for (int i = 0; i < expectation.length; i++) {
            total += Math.abs(expectation[i] - nowoutput[i]);
        }
        return total;
    }

    /**
     * Propagates the current input through every layer, caching each neuron's output and the
     * network outputs.
     *
     * @param p index of the network output to return
     * @return output {@code p} of the last layer
     */
    public double forward_pass(int p) {
        double[] layerInput = input;
        for (int layer = 1; layer <= NeuronsC.length; layer++) {
            int size = NeuronsC[layer - 1];
            double[] layerOutput = new double[size];
            for (int row = 1; row <= size; row++) {
                Neurons neuron = getNeurons(layer, row);
                neuron.setInput(layerInput);
                layerOutput[row - 1] = neuron.calaOutput();
            }
            layerInput = layerOutput;
        }
        System.arraycopy(layerInput, 0, nowoutput, 0, nowoutput.length);
        return nowoutput[p];
    }

    /**
     * Performs one back-propagation step for the current sample.
     *
     * <p>First computes dE/dz and dE/dw for every neuron, from the output layer backwards (using
     * the not-yet-updated weights). Then applies {@code w -= dE/dw * trainrate} to all neurons.
     */
    public void update_weights() {
        double[] outputs = getNowoutput(); // forward pass: every neuron caches its output
        int last = NeuronsC.length;

        // Output layer: dE/dz = (o - t) * o(1 - o).
        for (int row = 1; row <= NeuronsC[last - 1]; row++) {
            Neurons neuron = getNeurons(last, row);
            neuron.setEz((outputs[row - 1] - this.expectation[row - 1])
                    * active_derivative(neuron.getNowoutput()));
            storeWeightGradient(neuron);
        }

        // Hidden layers, back to front: dE/dz = (sum_k ez_k * w_k[row]) * o(1 - o).
        for (int layer = last - 1; layer >= 1; layer--) {
            for (int row = 1; row <= NeuronsC[layer - 1]; row++) {
                Neurons neuron = getNeurons(layer, row);
                double ea = 0; // dE/d(output of this neuron)
                for (int k = 1; k <= NeuronsC[layer]; k++) {
                    Neurons next = getNeurons(layer + 1, k);
                    ea += next.getEz() * next.getWeight()[row - 1];
                }
                neuron.setEz(ea * active_derivative(neuron.getNowoutput()));
                storeWeightGradient(neuron);
            }
        }

        // Apply all updates only after every gradient has been computed.
        for (Neurons neuron : AllNeurons) {
            double[] w = neuron.getWeight();
            double[] grad = neuron.getEw();
            for (int i = 0; i < w.length; i++) {
                w[i] -= grad[i] * trainrate;
            }
        }
    }

    /** Stores dE/dw_i = ez * x_i for each input weight and dE/db = ez for the bias. */
    private static void storeWeightGradient(Neurons neuron) {
        double ez = neuron.getEz();
        double[] x = neuron.getInput();
        double[] ew = new double[neuron.getWeight().length];
        for (int i = 0; i < ew.length - 1; i++) {
            ew[i] = ez * x[i];
        }
        ew[ew.length - 1] = ez;
        neuron.setEw(ew);
    }

    /** Runs a forward pass and returns the (internal, shared) output array. */
    public double[] getNowoutput() {
        forward_pass(0);
        return nowoutput;
    }

    /** Returns output {@code i} from the most recent forward pass without recomputing. */
    public double getNowoutput(int i) {
        return nowoutput[i];
    }

    // --------------------输出打印---------------------------------------------
    public void getweighttostring() {
        for (Neurons neurons : AllNeurons) {
            neurons.getweighttostring();
        }
    }

    public void getEwtostring() {
        for (Neurons neurons : AllNeurons) {
            neurons.getEwtostring();
        }
    }

    public void getEztostring() {
        for (Neurons neurons : AllNeurons) {
            neurons.getEztostring();
        }
    }

    public void getinputtostring() {
        for (Neurons neurons : AllNeurons) {
            neurons.getinputtostring();
        }
    }

    public void getoutnputtostring() {
        for (Neurons neurons : AllNeurons) {
            neurons.getoutnputtostring();
        }
    }

    /** Logistic activation {@code 1 / (1 + e^-x)}. */
    public double active(double input) {
        return 1.0 / (1.0 + Math.exp(-input));
    }

    /**
     * Sigmoid derivative expressed in terms of the sigmoid's output {@code o}:
     * {@code o * (1 - o)}. The argument is the activated output, not the pre-activation sum.
     */
    public double active_derivative(double input) {
        return input * (1 - input);
    }

    // --------------------输出打印---------------------------------------------

    /**
     * Returns the neuron at layer {@code col}, position {@code row} (both 1-based), or prints a
     * message and returns null when the coordinates are out of range.
     */
    public Neurons getNeurons(int col, int row) {
        if (col < 1 || col > NeuronsC.length || row < 1 || row > NeuronsC[col - 1]) {
            System.out.println("该层没有这么多神经元!!请求" + col + "列" + row + "行神经元");
            return null;
        }
        int n = 0;
        for (int i = 0; i < col - 1; i++) {
            n += NeuronsC[i];
        }
        return AllNeurons[n + row - 1];
    }
}
主类（t，包含 main 入口）
/**
 * Demo entry point: trains a {@link NetWork} on four 2-input / 2-output samples, prints the
 * resulting outputs and evaluation error, then draws the layer layout.
 */
public class t {
    /** Training stops once the summed absolute error over one epoch drops below this. */
    private static final double TARGET_ERROR = 0.08;
    /** Safety cap so a non-converging run cannot loop forever. */
    private static final int MAX_EPOCHS = 1_000_000;

    public static void main(String[] args) {
        int[] a = { 5, 5, 4, 7, 9, 2 };
        int outputs = a[a.length - 1];
        // Each sample is {input1, input2, label1, label2}; append rows here to train on more data.
        double[][] samples = {
            { 1, 1, 1, 1 },
            { 1, 0, 1, 0 },
            { 0, 1, 1, 0 },
            { 0, 0, 0, 1 },
        };
        NetWork network = new NetWork(a, 2, outputs);

        double err = 1;
        int i;
        for (i = 0; err >= TARGET_ERROR && i < MAX_EPOCHS; i++) {
            err = 0;
            for (double[] sample : samples) {
                network.setInputandLable(sample);
                network.update_weights();
                err += Math.abs(network.err());
            }
        }
        if (err >= TARGET_ERROR) {
            System.out.println("达到最大训练次数仍未收敛");
        }
        System.out.println("误差:" + err);
        System.out.println("运行次数:" + i);

        // Fresh accumulator: the reported figure is the evaluation error only, not the leftover
        // training error plus the evaluation error.
        double evalErr = 0;
        for (double[] sample : samples) {
            network.setInputandLable(sample);
            evalErr += Math.abs(network.err());
            StringBuilder line = new StringBuilder();
            for (int k = 0; k < outputs; k++) {
                if (k > 0) {
                    line.append(' ');
                }
                line.append(network.getNowoutput(k));
            }
            System.out.println(line);
        }
        System.out.println("误差:" + evalErr);
        drawnet(a);
    }

    /**
     * Prints one column of dots per layer, each column vertically centred on the tallest layer.
     *
     * @param a neurons per layer
     */
    private static void drawnet(int[] a) {
        final String rule = "----------------------------------------------------------------------------------------------";
        System.out.println(rule);
        System.out.println("神经网络图:");
        int max = 0;
        for (int size : a) {
            max = Math.max(max, size);
        }
        for (int row = 0; row < max; row++) {
            StringBuilder line = new StringBuilder();
            for (int size : a) {
                boolean filled = row >= (max - size) / 2 && row < (max + size) / 2;
                // Empty cells must be as wide as " ● " or the columns drift out of alignment.
                line.append(filled ? " ● " : "   ");
            }
            System.out.println(line);
        }
        System.out.println(rule);
    }
}
输出结果:
生成32个神经元
误差:0.079968261269626
运行次数:31782
0.9999784568902769 0.9746964043839443
0.9915416453943027 0.026059584935733092
0.9915361463055753 0.026071465598939452
0.01106848775432035 0.9694260287169286
误差:0.16000772575870634
----------------------------------------------------------------------------------------------
神经网络图:
             ●
          ●  ●
 ●  ●  ●  ●  ●
 ●  ●  ●  ●  ●  ●
 ●  ●  ●  ●  ●  ●
 ●  ●  ●  ●  ●
 ●  ●     ●  ●
          ●  ●
             ●
----------------------------------------------------------------------------------------------