BP class // encapsulates the BP (backpropagation) neural-network algorithm
package arithmetic;
/**
 * Three-layer (input, hidden, output) feed-forward neural network trained with
 * per-sample backpropagation.
 *
 * <p>The hidden layer uses {@code F.f1} as activation (derivative {@code F.f1_1}),
 * the output layer uses {@code F.f2} (derivative {@code F.f2_1}). Layer sizes are
 * taken from the weight matrices: {@code W1} is indexed [hidden][input] and
 * {@code W2} is indexed [output][hidden].
 *
 * <p>The forward/backward steps communicate through instance fields, so the
 * {@code cal*}/{@code change*} methods must be called in order for one sample at a
 * time. The class performs no synchronization.
 */
public class BP {
    /** Current input vector (one sample), set by {@link #setP(double[])}. */
    private double[] P;
    /** Current target vector (one sample), set by {@link #setT(double[])}. */
    private double[] T;
    /** Input-to-hidden weights, [hidden][input]; updated in place by training. */
    private double[][] W1;
    /** Hidden-to-output weights, [output][hidden]; updated in place by training. */
    private double[][] W2;
    /** Number of inputs. */
    private int n_a0;
    /** Number of hidden neurons. */
    private int n_a1;
    /** Number of output neurons. */
    private int n_a2;
    /** Hidden-layer biases; updated in place by training. */
    private double[] B1;
    /** Output-layer biases; updated in place by training. */
    private double[] B2;
    /** Hidden-layer activations from the last {@link #calA1()}. */
    private double[] a1;
    /** Output-layer activations from the last {@link #calA2()}. */
    private double[] a2;
    /** Output-layer error terms {@code (T - a2) * f2'(a2)} from the last {@link #calDb2()}. */
    private double[] q;
    /** Pending hidden-bias updates. */
    private double[] db1;
    /** Pending output-bias updates. */
    private double[] db2;
    /** Pending input-to-hidden weight updates. */
    private double[][] dw1;
    /** Pending hidden-to-output weight updates. */
    private double[][] dw2;
    /** Error of the last evaluated sample: half the sum of squared output errors. */
    private double e;
    /** Learning rate. */
    private double r;
    /** Per-sample error threshold; samples with {@code e < e0} cause no update. */
    private double e0;

    /**
     * Creates a network from initial weights and biases.
     *
     * <p>The arrays are stored by reference (not copied), so training modifies the
     * caller's arrays.
     *
     * @param W1 input-to-hidden weights, [hidden][input]; must have at least one row
     * @param W2 hidden-to-output weights, [output][hidden]; must have at least one row
     * @param B1 hidden biases, one per row of {@code W1}
     * @param B2 output biases, one per row of {@code W2}
     * @throws IllegalArgumentException if the dimensions are inconsistent
     */
    public BP(double[][] W1, double[][] W2, double[] B1, double[] B2) {
        if (W1.length == 0 || W2.length == 0) {
            throw new IllegalArgumentException("W1 and W2 must each have at least one row");
        }
        this.W1 = W1;
        this.W2 = W2;
        this.B1 = B1;
        this.B2 = B2;
        n_a0 = W1[0].length;
        n_a1 = W1.length;
        n_a2 = W2.length;
        for (double[] row : W1) {
            if (row.length != n_a0) {
                throw new IllegalArgumentException("every row of W1 must have " + n_a0 + " columns");
            }
        }
        for (double[] row : W2) {
            if (row.length != n_a1) {
                throw new IllegalArgumentException("every row of W2 must have " + n_a1
                        + " columns (the number of hidden neurons)");
            }
        }
        if (B1.length != n_a1) {
            throw new IllegalArgumentException("B1 has " + B1.length + " entries, expected " + n_a1);
        }
        if (B2.length != n_a2) {
            throw new IllegalArgumentException("B2 has " + B2.length + " entries, expected " + n_a2);
        }
        init();
    }

    /** Sets the current input vector (expected length: number of inputs). */
    public void setP(double[] P) {
        this.P = P;
    }

    /** Sets the current target vector (expected length: number of outputs). */
    public void setT(double[] T) {
        this.T = T;
    }

    /** Allocates work arrays and sets the default learning rate (0.4) and error threshold (0.02). */
    private void init() {
        a1 = new double[n_a1];
        a2 = new double[n_a2];
        r = 0.4;
        e0 = 0.02;
        q = new double[n_a2];
        db2 = new double[n_a2];
        dw2 = new double[n_a2][n_a1];
        db1 = new double[n_a1];
        dw1 = new double[n_a1][n_a0];
    }

    /**
     * Sets the learning rate used by the update computations.
     *
     * @throws IllegalArgumentException if {@code r} is not a positive number
     */
    public void setLearningRate(double r) {
        if (!(r > 0)) {
            throw new IllegalArgumentException("learning rate must be positive: " + r);
        }
        this.r = r;
    }

    /** Returns the learning rate (default 0.4). */
    public double getLearningRate() {
        return r;
    }

    /**
     * Sets the per-sample error threshold below which a sample triggers no update.
     *
     * @throws IllegalArgumentException if {@code e0} is not a positive number
     */
    public void setTargetError(double e0) {
        if (!(e0 > 0)) {
            throw new IllegalArgumentException("target error must be positive: " + e0);
        }
        this.e0 = e0;
    }

    /** Returns the per-sample error threshold (default 0.02). */
    public double getTargetError() {
        return e0;
    }

    /** Forward pass, hidden layer: {@code a1[i] = f1(sum_j W1[i][j] * P[j] + B1[i])}. */
    public void calA1() {
        for (int i = 0; i < n_a1; i++) {
            // Fresh accumulator per neuron so each net input only sums its own weights.
            double temp = 0;
            for (int j = 0; j < n_a0; j++) {
                temp += W1[i][j] * P[j];
            }
            temp += B1[i];
            a1[i] = F.f1(temp);
        }
    }

    /** Returns the hidden activations array (live internal array, overwritten by {@link #calA1()}). */
    public double[] getA1() {
        return a1;
    }

    /** Returns the output activations array (live internal array, overwritten by {@link #calA2()}). */
    public double[] getA2() {
        return a2;
    }

    /** Forward pass, output layer: {@code a2[k] = f2(sum_i W2[k][i] * a1[i] + B2[k])}. */
    public void calA2() {
        for (int k = 0; k < n_a2; k++) {
            // Fresh accumulator per output neuron.
            double temp = 0;
            for (int i = 0; i < n_a1; i++) {
                temp += W2[k][i] * a1[i];
            }
            temp += B2[k];
            a2[k] = F.f2(temp);
        }
    }

    /** Computes the sample error {@code e = 0.5 * sum_k (T[k] - a2[k])^2}. */
    public void calE() {
        double sum = 0;
        for (int k = 0; k < n_a2; k++) {
            double ek = T[k] - a2[k];
            sum += ek * ek;
        }
        // Halve once, after summing all outputs.
        e = sum / 2;
    }

    /** Computes output error terms {@code q} and the output-bias updates {@code db2 = r * q}. */
    public void calDb2() {
        for (int k = 0; k < n_a2; k++) {
            q[k] = (T[k] - a2[k]) * F.f2_1(a2[k]);
            db2[k] = q[k] * r;
        }
    }

    /** Computes hidden-to-output weight updates {@code dw2[k][i] = db2[k] * a1[i]}. */
    public void calDw2() {
        for (int k = 0; k < n_a2; k++) {
            for (int i = 0; i < n_a1; i++) {
                dw2[k][i] = db2[k] * a1[i];
            }
        }
    }

    /**
     * Back-propagates the output error terms to the hidden layer:
     * {@code db1[i] = r * f1'(a1[i]) * sum_k q[k] * W2[k][i]}.
     * Must run after {@link #calDb2()} and before {@link #changeDw2()} so it sees the
     * pre-update {@code W2}.
     */
    public void calDb1() {
        for (int i = 0; i < n_a1; i++) {
            db1[i] = 0;
            for (int k = 0; k < n_a2; k++) {
                db1[i] += q[k] * W2[k][i];
            }
            db1[i] *= r * F.f1_1(a1[i]);
        }
    }

    /** Computes input-to-hidden weight updates {@code dw1[i][j] = db1[i] * P[j]}. */
    public void calDw1() {
        for (int i = 0; i < n_a1; i++) {
            for (int j = 0; j < n_a0; j++) {
                dw1[i][j] = db1[i] * P[j];
            }
        }
    }

    /** Applies the pending output-bias updates. */
    public void changeDb2() {
        for (int i = 0; i < n_a2; i++) {
            B2[i] += db2[i];
        }
    }

    /** Applies the pending hidden-to-output weight updates. */
    public void changeDw2() {
        for (int i = 0; i < n_a2; i++) {
            for (int j = 0; j < n_a1; j++) {
                W2[i][j] += dw2[i][j];
            }
        }
    }

    /** Applies the pending hidden-bias updates. */
    public void changeDb1() {
        for (int i = 0; i < n_a1; i++) {
            B1[i] += db1[i];
        }
    }

    /** Applies the pending input-to-hidden weight updates. */
    public void changeDw1() {
        for (int i = 0; i < n_a1; i++) {
            for (int j = 0; j < n_a0; j++) {
                W1[i][j] += dw1[i][j];
            }
        }
    }

    /**
     * Trains until one full pass over the samples makes no update (every sample's
     * error is below the threshold), then prints "train succeed".
     *
     * <p>This method has no iteration limit and does not return if training never
     * converges; see {@link #train(double[][], double[][], int)} for a bounded variant.
     *
     * @param P input samples, one row per sample
     * @param T target vectors, one row per sample
     * @throws IllegalArgumentException if {@code P} and {@code T} differ in length
     */
    public void train(double[][] P, double[][] T) {
        checkTrainingSet(P, T);
        while (trainEpoch(P, T)) {
            // Keep iterating until an epoch makes no weight update.
        }
        System.out.println("train succeed");
    }

    /**
     * Trains for at most {@code maxEpochs} passes over the samples.
     *
     * @param P         input samples, one row per sample
     * @param T         target vectors, one row per sample
     * @param maxEpochs maximum number of passes; values {@code <= 0} perform no training
     * @return {@code true} (after printing "train succeed") if a pass made no update,
     *         {@code false} if the limit was reached first
     * @throws IllegalArgumentException if {@code P} and {@code T} differ in length
     */
    public boolean train(double[][] P, double[][] T, int maxEpochs) {
        checkTrainingSet(P, T);
        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            if (!trainEpoch(P, T)) {
                System.out.println("train succeed");
                return true;
            }
        }
        return false;
    }

    /** Rejects training sets whose input and target counts differ. */
    private static void checkTrainingSet(double[][] inputs, double[][] targets) {
        if (inputs.length != targets.length) {
            throw new IllegalArgumentException("P has " + inputs.length
                    + " samples but T has " + targets.length);
        }
    }

    /**
     * Runs one pass over all samples, updating weights for every sample whose
     * error is not below the threshold.
     *
     * @return {@code true} if at least one sample caused an update
     */
    private boolean trainEpoch(double[][] inputs, double[][] targets) {
        boolean isChange = false;
        for (int n = 0; n < inputs.length; n++) {
            setP(inputs[n]);
            setT(targets[n]);
            this.calA1();
            this.calA2();
            this.calE();
            if (e < e0) {
                continue;
            }
            // All deltas are computed before any weight changes, so calDb1 uses the old W2.
            this.calDb2();
            this.calDw2();
            this.calDb1();
            this.calDw1();
            this.changeDb2();
            this.changeDw2();
            this.changeDb1();
            this.changeDw1();
            isChange = true;
        }
        return isChange;
    }

    /**
     * Runs a forward pass for one input and returns the output activations.
     *
     * <p>{@code t} is stored as the current target but is not used by the forward
     * pass. The returned array is the live internal {@code a2}, overwritten by the
     * next forward pass.
     *
     * @param p input vector
     * @param t target vector (stored only)
     * @return the output activations
     */
    public double[] divide(double[] p, double[] t) {
        setP(p);
        setT(t);
        this.calA1();
        this.calA2();
        return a2;
    }

    /** Returns the error computed by the last {@link #calE()}. */
    public double getE() {
        return e;
    }
}
F class // encapsulates the neuron activation functions and their derivatives
package arithmetic;
/**
 * Activation functions used by the BP network, each paired with its derivative.
 * The derivative helpers take the function's OUTPUT, not its input.
 */
public class F {

    /**
     * Hidden-layer activation: the logistic sigmoid {@code 1 / (1 + e^(-x))}.
     *
     * @param x net input
     * @return a value between 0 and 1
     */
    public static double f1(double x){
        final double expOfNegX = Math.exp(-x);
        final double denominator = 1.0 + expOfNegX;
        return 1.0 / denominator;
    }

    /**
     * Derivative of {@link #f1} written in terms of its output {@code y = f1(x)}.
     *
     * @param y sigmoid output
     * @return {@code y * (1 - y)}
     */
    public static double f1_1(double y){
        final double complement = 1.0 - y;
        return complement * y;
    }

    /**
     * Output-layer activation: the identity (linear) function.
     *
     * @param x net input
     * @return {@code x} unchanged
     */
    public static double f2(double x){
        final double identity = x;
        return identity;
    }

    /**
     * Derivative of {@link #f2}; the identity has constant slope 1.
     *
     * @param y output of {@link #f2} (ignored)
     * @return always {@code 1.0}
     */
    public static double f2_1(double y){
        final double slope = 1.0;
        return slope;
    }
}
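Controller class (declared as `Controler`): reads the training and test data, then creates a BP instance, trains it, and evaluates it on the test set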
package arithmetic;
import java.io.*;
import java.util.ArrayList;
import javax.swing.JFrame;
/**
 * Driver that loads {@code testData.txt} and {@code trainData.txt}, trains a
 * 4-3-2 (inputs-hidden-outputs) {@link BP} network on the training set, and prints
 * per-sample results plus the overall test accuracy (a truncated percentage).
 *
 * <p>Each non-blank data line is whitespace-separated: an integer class label
 * followed by at least {@value #N_FEATURES} numeric features (extra columns are
 * ignored). Labels 0, 1, 2 are encoded as the 2-element targets (0,0), (0,1),
 * (1,0); any other label is encoded as (1,1).
 */
public class Controler {
    /** Number of feature columns read from each data line (after the label). */
    private static final int N_FEATURES = 4;

    private double[][] p_test;
    private double[][] t_test;
    private double[][] p_train;
    private double[][] t_train;
    private BP bp;

    /**
     * Runs the whole load / train / evaluate cycle.
     *
     * @throws IOException if a data file cannot be opened or read, or has a line
     *                     with fewer than {@value #N_FEATURES} features
     */
    public Controler() throws IOException {
        getTestData();
        getTrainData();
        double[][] w1 = new double[][] { { 0.2, 0.3, 0.4, 0.1 },
                { 0.3, 0.4, 0.2, 0.4 }, { 0.4, 0.8, 0.9, 0.3 } };
        double[][] w2 = new double[][] { { 0.3, 0.6, 0.7 }, { 0.1, 0.3, 0.7 } };
        double[] b1 = new double[] { 0.2, 0.4, 0.5 };
        double[] b2 = new double[] { 0.1, 0.5 };
        bp = new BP(w1, w2, b1, b2);
        // NOTE(review): BP.train(P, T) loops until an epoch makes no update, so this
        // call does not return if training never reaches the error threshold.
        bp.train(p_train, t_train);

        int correct = 0;
        for (int i = 0; i < p_test.length; i++) {
            double[] a2 = bp.divide(p_test[i], t_test[i]);
            int t0 = decodeClass(t_test[i]);
            int t1 = decodeClass(a2);
            boolean equals = t1 == t0;
            if (equals) {
                correct++;
            }
            System.out.println("expected:" + t0 + "\t" + "output:" + t1 + "\t" + equals);
        }
        // Accuracy over the actual number of test samples, truncated to an int percentage.
        int accuracy = p_test.length == 0 ? 0 : (int) (correct * 100.0 / p_test.length);
        System.out.println(accuracy);
    }

    /** Loads and normalizes {@code testData.txt} into {@code p_test}/{@code t_test}. */
    private void getTestData() throws IOException {
        DataSet data = readData("testData.txt");
        p_test = data.inputs;
        t_test = data.targets;
    }

    /** Loads and normalizes {@code trainData.txt} into {@code p_train}/{@code t_train}. */
    private void getTrainData() throws IOException {
        DataSet data = readData("trainData.txt");
        p_train = data.inputs;
        t_train = data.targets;
    }

    /** Parsed inputs and encoded targets from one data file. */
    private static final class DataSet {
        final double[][] inputs;
        final double[][] targets;

        DataSet(double[][] inputs, double[][] targets) {
            this.inputs = inputs;
            this.targets = targets;
        }
    }

    /**
     * Reads a data file, encodes the labels and scales every feature column by that
     * column's maximum within the same file (columns whose maximum is 0 are left as is).
     *
     * @throws IOException if the file cannot be read or a line has too few columns
     */
    private static DataSet readData(String fileName) throws IOException {
        ArrayList<String> lines = new ArrayList<>();
        // try-with-resources closes the reader on every path, including read errors.
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String s;
            while ((s = br.readLine()) != null) {
                String trimmed = s.trim();
                if (!trimmed.isEmpty()) { // tolerate blank lines, e.g. a trailing newline
                    lines.add(trimmed);
                }
            }
        }

        int n = lines.size();
        double[][] inputs = new double[n][N_FEATURES];
        double[][] targets = new double[n][2];
        double[] maxData = new double[N_FEATURES];
        for (int i = 0; i < n; i++) {
            String[] fields = lines.get(i).split("\\s+");
            if (fields.length < N_FEATURES + 1) {
                throw new IOException(fileName + ": record " + (i + 1) + " needs a label and "
                        + N_FEATURES + " features but was: " + lines.get(i));
            }
            for (int j = 0; j < N_FEATURES; j++) {
                inputs[i][j] = Double.parseDouble(fields[j + 1]);
                if (inputs[i][j] > maxData[j]) {
                    maxData[j] = inputs[i][j];
                }
            }
            encodeLabel(Integer.parseInt(fields[0]), targets[i]);
        }

        // NOTE(review): each file is scaled by its own column maxima, so test data is
        // not scaled with the training statistics; confirm this is intended.
        for (double[] row : inputs) {
            for (int j = 0; j < N_FEATURES; j++) {
                if (maxData[j] != 0) { // avoid NaN/Infinity for an all-zero column
                    row[j] /= maxData[j];
                }
            }
        }
        return new DataSet(inputs, targets);
    }

    /** Writes the 2-element target for {@code label}: 0->(0,0), 1->(0,1), 2->(1,0), other->(1,1). */
    private static void encodeLabel(int label, double[] target) {
        switch (label) {
            case 0:
                target[0] = 0;
                target[1] = 0;
                break;
            case 1:
                target[0] = 0;
                target[1] = 1;
                break;
            case 2:
                target[0] = 1;
                target[1] = 0;
                break;
            default:
                target[0] = 1;
                target[1] = 1;
                break;
        }
    }

    /**
     * Decodes a 2-element target or network output into a class index 0..3 by
     * thresholding each element at 0.5 (element 0 is the high bit).
     */
    private static int decodeClass(double[] v) {
        return (v[0] >= 0.5 ? 2 : 0) + (v[1] >= 0.5 ? 1 : 0);
    }

    /** Entry point: the constructor performs the whole run. */
    public static void main(String[] args) throws Exception {
        new Controler();
    }
}