作者:桂子山下一棵草 email: slowguy@qq.com
题目:
表一: 澳大利亚野兔眼睛晶状体重量与年龄的对应关系
| 编号 | 年龄(天) | 重量(mg) | 年龄(天) | 重量(mg) | 年龄(天) | 重量(mg) | 年龄(天) | 重量(mg) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 15 | 21.66 | 75 | 94.6 | 218 | 174.18 | 338 | 203.23 |
| 2 | 15 | 22.75 | 82 | 92.5 | 218 | 173.03 | 347 | 188.38 |
| 3 | 15 | 22.3 | 85 | 105 | 219 | 173.54 | 354 | 189.7 |
| 4 | 18 | 31.25 | 91 | 101.7 | 224 | 178.86 | 357 | 195.31 |
| 5 | 28 | 44.79 | 91 | 102.9 | 225 | 177.68 | 375 | 202.63 |
| 6 | 29 | 40.55 | 97 | 110 | 227 | 173.73 | 394 | 224.82 |
| 7 | 37 | 50.25 | 98 | 104.3 | 232 | 159.98 | 513 | 203.3 |
| 8 | 37 | 46.88 | 125 | 134.9 | 232 | 161.29 | 535 | 209.7 |
| 9 | 44 | 52.03 | 142 | 130.68 | 237 | 187.07 | 554 | 233.9 |
| 10 | 50 | 63.47 | 142 | 140.58 | 246 | 176.13 | 591 | 234.7 |
| 11 | 50 | 61.13 | 147 | 155.3 | 258 | 183.4 | 648 | 244.3 |
| 12 | 60 | 81 | 147 | 152.2 | 276 | 186.26 | 660 | 231 |
| 13 | 61 | 73.09 | 150 | 144.5 | 285 | 189.66 | 705 | 242.4 |
| 14 | 64 | 79.09 | 159 | 142.15 | 300 | 186.09 | 723 | 230.77 |
| 15 | 65 | 79.51 | 165 | 139.81 | 301 | 186.7 | 756 | 242.57 |
| 16 | 65 | 65.31 | 183 | 153.22 | 305 | 186.8 | 768 | 232.12 |
| 17 | 72 | 71.9 | 192 | 145.72 | 312 | 195.1 | 860 | 246.7 |
| 18 | 75 | 86.1 | 195 | 161.1 | 317 | 216.41 | | |
澳大利亚野兔眼睛晶状体的重量为年龄的函数。利用BP算法,设计一个多层感知器,为表中的数据集提供一个非线性逼近,并测试其泛化能力。
算法源码:
package com.lwm.cn.althom;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.Random;
public class BackProp {
private int randomPrecision = 8; // precision argument handed to MathExtend.round for every random initial value (currently 8; presumably a number of decimal places - confirm in MathExtend)
private int input_dimension; // dimension of the input vector
private int output_dimension; // dimension of the output vector
private int mid_dimension; // number of hidden-layer nodes
private double[][] V; // input-to-hidden weight matrix, indexed [input][hidden]
private double[][] W; // hidden-to-output weight matrix, indexed [hidden][output]
private double[] inputArray; // current input vector
private double[] midArray; // hidden-layer activations produced by forword()
private double[] outputArray; // output-layer activations produced by forword()
private double[] teacherArray; // expected (target) output vector
private double mid_Threshold; // threshold subtracted from every hidden node's weighted sum (one value shared by all hidden nodes)
private double out_Threshold; // threshold subtracted from every output node's weighted sum (one value shared by all output nodes)
private double[] midError; // hidden-layer error signals computed by adjustWeight()
private double[] outError; // output-layer error signals computed by adjustWeight()
private double totalError = 0.0; // running sum of (target - output)^2 / 2, accumulated by forword()
private double outPrecision; // training stops once totalError is at or below this value
private double learnRate; // learning rate used when adjusting weights
private int trainTotal = 3000; // maximum number of training passes over the training set
private boolean isQualify = false; // set to true when training reached the required precision
private ArrayList<SampleNode> trainArray = new ArrayList<SampleNode>(100); // training samples
private ArrayList<SampleNode> testArray = new ArrayList<SampleNode>(100); // test samples
private BufferedWriter bw = null; // writer for the result file that records training and testing output
Date startTime; // set in init(); used for the result file name and the elapsed-time report
// SampleNode sample;
// Math.random()
/**
 * Creates a network with caller-supplied weight matrices and layer sizes.
 *
 * <p>All working buffers (layer activations, target vector and error
 * signals) are allocated here; without them every later call that touches
 * those arrays would fail with a NullPointerException. Note that
 * {@code init()} overwrites every element of both matrices with random
 * values.
 *
 * @param v                input-to-hidden weights, indexed [input][hidden]
 * @param w                hidden-to-output weights, indexed [hidden][output]
 * @param input_dimension  number of input nodes
 * @param output_dimension number of output nodes
 * @param mid_dimension    number of hidden nodes
 */
public BackProp(double[][] v, double[][] w, int input_dimension,
        int output_dimension, int mid_dimension) {
    super();
    V = v;
    W = w;
    this.input_dimension = input_dimension;
    this.output_dimension = output_dimension;
    this.mid_dimension = mid_dimension;
    // Same buffers the no-argument constructor creates.
    inputArray = new double[input_dimension];
    teacherArray = new double[output_dimension];
    midArray = new double[mid_dimension];
    outputArray = new double[output_dimension];
    midError = new double[mid_dimension];
    outError = new double[output_dimension];
}
/**
 * Default constructor: one input node, one output node and eight hidden
 * nodes, which is the topology used for this experiment.
 */
public BackProp() {
    input_dimension = 1;
    output_dimension = 1;
    mid_dimension = 8;
    inputArray = new double[input_dimension];
    teacherArray = new double[output_dimension];
    midArray = new double[mid_dimension];
    outputArray = new double[output_dimension];
    V = new double[input_dimension][mid_dimension];
    // W maps hidden nodes to OUTPUT nodes, so its second dimension is the
    // output size (the two only coincide because both are 1 here).
    W = new double[mid_dimension][output_dimension];
    midError = new double[mid_dimension];
    outError = new double[output_dimension];
}
/**
 * Prepares a run: opens the result file, randomises the network
 * parameters, then loads and normalises the training and test data.
 *
 * <p>The result file is named after the start time. Both thresholds, every
 * element of V and W, the learning rate and the target precision are drawn
 * from {@code Math.random()} and passed through {@code MathExtend.round}
 * with {@code randomPrecision}.
 *
 * @throws IllegalStateException if the result file cannot be created; the
 *         original code swallowed that error and then failed with a
 *         NullPointerException on the next write
 */
public void init()
{
    // The start time doubles as the result file name and as the reference
    // point for the elapsed-time report.
    startTime = new Date();
    SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
    String timeStr = sdf.format(startTime);
    String filePathName = "E:" + File.separator + timeStr + ".txt";
    try
    {
        bw = new BufferedWriter(new FileWriter(filePathName));
    } catch (IOException e)
    {
        // Nothing else in this class can work without the writer.
        throw new IllegalStateException("cannot create result file: " + filePathName, e);
    }
    mid_Threshold = MathExtend.round(Math.random(), randomPrecision); // hidden-layer threshold
    out_Threshold = MathExtend.round(Math.random(), randomPrecision); // output-layer threshold
    // Random initial input-to-hidden weights.
    for (int i = 0; i < input_dimension; i++)
        for (int j = 0; j < mid_dimension; j++)
            V[i][j] = MathExtend.round(Math.random(), randomPrecision);
    // Random initial hidden-to-output weights.
    for (int i = 0; i < mid_dimension; i++)
        for (int j = 0; j < output_dimension; j++)
            W[i][j] = MathExtend.round(Math.random(), randomPrecision);
    // Reset the accumulated error; learning rate and target precision are
    // both random values for this experiment.
    totalError = 0.0;
    learnRate = MathExtend.round(Math.random(), randomPrecision);
    outPrecision = MathExtend.round(Math.random(), randomPrecision);
    try
    {
        StringBuilder sb = new StringBuilder();
        sb.append("程序开始时间:" + timeStr + "\n");
        sb.append("本次实验随机生成的学习率: " + learnRate);
        sb.append("\n");
        sb.append("期望达到的精度为: " + outPrecision);
        sb.append("\n");
        bw.write(sb.toString());
    } catch (IOException e)
    {
        // Logging is best-effort once the file is open.
        e.printStackTrace();
    }
    getTrainData(); // load the training set
    getTestData(); // load the test set
    normalized(); // scale both sets
}
/**
 * Closes the result file writer.
 *
 * <p>Does nothing when the writer was never opened, so it is safe to call
 * even if {@code init()} was not run.
 */
public void finish()
{
    if (bw == null)
        return;
    try
    {
        bw.close();
    } catch (IOException e)
    {
        e.printStackTrace();
    }
}
/**
 * Runs one forward pass for the sample currently held in
 * {@code inputArray} / {@code teacherArray}.
 *
 * <p>Fills {@code midArray} and {@code outputArray} with sigmoid
 * activations and adds this sample's squared error (halved) to
 * {@code totalError}.
 */
public void forword()
{
    int i, j;
    double temp_sum; // weighted sum feeding the current node
    // Input layer -> hidden layer.
    for (i = 0; i < mid_dimension; i++)
    {
        temp_sum = 0.0;
        for (j = 0; j < input_dimension; j++)
            temp_sum += V[j][i] * inputArray[j];
        temp_sum = temp_sum - mid_Threshold;
        midArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
    }
    // Hidden layer -> output layer.
    for (i = 0; i < output_dimension; i++)
    {
        temp_sum = 0.0;
        for (j = 0; j < mid_dimension; j++)
            // Accumulate over ALL hidden nodes; a plain assignment here kept
            // only the last hidden node's contribution.
            temp_sum += W[j][i] * midArray[j];
        temp_sum = temp_sum - out_Threshold;
        outputArray[i] = 1.0 / (1 + Math.exp(-temp_sum));
    }
    // Accumulate the squared error of this sample.
    for (i = 0; i < output_dimension; i++)
    {
        temp_sum = teacherArray[i] - outputArray[i];
        totalError += temp_sum * temp_sum / 2;
    }
}
/**
 * Prints the current input, the network's actual output and the expected
 * output on one console line (first component of each vector only).
 */
private void printResult() {
    String report = "输入数据:" + inputArray[0]
            + " 实际输出数据:" + outputArray[0]
            + " 期望输出数据为:" + teacherArray[0];
    System.out.println(report);
}
/**
 * Back-propagates the error of the most recent forward pass and adjusts
 * both weight matrices and both thresholds.
 *
 * <p>Error signals are computed from the current contents of
 * {@code teacherArray}, {@code outputArray} and {@code midArray}. Hidden
 * error signals are computed before W is changed so that they use the same
 * weights the forward pass used.
 */
public void adjustWeight()
{
    double temp_sum = 0.0;
    int i, j;
    // Output-layer error signal: (target - output) * sigmoid'(net).
    for (i = 0; i < output_dimension; i++)
    {
        outError[i] = (teacherArray[i] - outputArray[i])
                * (1 - outputArray[i]) * outputArray[i];
    }
    // Hidden-layer error signal: back-propagated output error * sigmoid'(net).
    for (i = 0; i < mid_dimension; i++)
    {
        temp_sum = 0.0d;
        for (j = 0; j < output_dimension; j++)
            temp_sum += outError[j] * W[i][j];
        midError[i] = temp_sum * (1 - midArray[i]) * midArray[i];
    }
    // Adjust hidden-to-output weights.
    for (i = 0; i < mid_dimension; i++)
    {
        for (j = 0; j < output_dimension; j++)
            W[i][j] += learnRate * outError[j] * midArray[i];
    }
    // Adjust input-to-hidden weights.
    for (i = 0; i < input_dimension; i++)
        for (j = 0; j < mid_dimension; j++)
            V[i][j] += learnRate * midError[j] * inputArray[i];
    // Adjust the thresholds as well; they were previously left at their
    // random initial values. A threshold is SUBTRACTED from the weighted
    // sum, so its update has the opposite sign of a weight update, and
    // because one threshold is shared by a whole layer the error signals
    // of that layer are summed.
    temp_sum = 0.0;
    for (j = 0; j < output_dimension; j++)
        temp_sum += outError[j];
    out_Threshold -= learnRate * temp_sum;
    temp_sum = 0.0;
    for (j = 0; j < mid_dimension; j++)
        temp_sum += midError[j];
    mid_Threshold -= learnRate * temp_sum;
}
/**
 * Loads the training samples from {@code E:\traindata.txt} into
 * {@code trainArray}.
 *
 * <p>Each non-blank line must hold exactly two whitespace-separated
 * numbers: the input value and the expected output. Loading stops at the
 * first malformed line, keeping the samples read so far. A missing file is
 * reported and leaves the list unchanged. The reader is always closed.
 */
public void getTrainData()
{
    String filePathName = "E:" + File.separator + "traindata.txt";
    BufferedReader br;
    try
    {
        br = new BufferedReader(new FileReader(filePathName));
    } catch (FileNotFoundException e)
    {
        System.out.println("traindata.txt文件不存在");
        e.printStackTrace();
        return; // no reader, nothing to load
    }
    try
    {
        String s;
        while ((s = br.readLine()) != null)
        {
            String line = s.trim();
            if (line.length() == 0)
                continue; // tolerate blank lines
            String[] data = line.split("[\\s]+");
            if (data.length != 2)
            {
                System.out.println("traindata文件数据有问题!");
                return;
            }
            double in;
            double hope;
            try
            {
                in = Double.parseDouble(data[0]);
                hope = Double.parseDouble(data[1]);
            } catch (NumberFormatException e)
            {
                // A non-numeric field is the same kind of data problem as a
                // wrong column count.
                System.out.println("traindata文件数据有问题!");
                return;
            }
            trainArray.add(new SampleNode(in, hope));
        }
    } catch (IOException e)
    {
        e.printStackTrace();
    } finally
    {
        try
        {
            br.close();
        } catch (IOException e)
        {
            e.printStackTrace();
        }
    }
    trainArray.trimToSize();
}
/**
 * Loads the test samples from {@code E:\testdata.txt} into
 * {@code testArray}.
 *
 * <p>Each non-blank line must hold exactly two whitespace-separated
 * numbers: the input value and the expected output. Loading stops at the
 * first malformed line, keeping the samples read so far. A missing file is
 * reported and leaves the list unchanged. The reader is always closed.
 */
public void getTestData()
{
    String fileName = "E:" + File.separator + "testdata.txt";
    BufferedReader br;
    try
    {
        br = new BufferedReader(new FileReader(fileName));
    } catch (FileNotFoundException e)
    {
        System.out.println("testdata.txt文件不存在");
        e.printStackTrace();
        return; // no reader, nothing to load
    }
    try
    {
        String s;
        while ((s = br.readLine()) != null)
        {
            String line = s.trim();
            if (line.length() == 0)
                continue; // tolerate blank lines
            String[] data = line.split("[\\s]+");
            if (data.length != 2)
            {
                System.out.println("testdata文件数据有问题!");
                return;
            }
            double in;
            double hope;
            try
            {
                in = Double.parseDouble(data[0]);
                hope = Double.parseDouble(data[1]);
            } catch (NumberFormatException e)
            {
                // A non-numeric field is the same kind of data problem as a
                // wrong column count.
                System.out.println("testdata文件数据有问题!");
                return;
            }
            testArray.add(new SampleNode(in, hope));
        }
    } catch (IOException e)
    {
        e.printStackTrace();
    } finally
    {
        try
        {
            br.close();
        } catch (IOException e)
        {
            e.printStackTrace();
        }
    }
    testArray.trimToSize();
}
/**
 * Scales the training and test samples in place: inputs are divided by
 * 1000 and expected outputs by 250.
 *
 * <p>Does nothing except print a message when either set is missing or
 * empty. Each set is scaled independently; the earlier version wrote the
 * scaled test samples into the training list, replacing training data.
 */
private void normalized()
{
    if (trainArray == null || trainArray.size() == 0 || testArray == null
            || testArray.size() == 0)
    {
        System.out.println("测试数据或者训练数据有问题!");
        return;
    }
    scaleSamples(trainArray);
    scaleSamples(testArray);
}

/**
 * Divides every sample's input by 1000 and expected output by 250, in place.
 *
 * @param samples the list whose elements are rescaled
 */
private void scaleSamples(ArrayList<SampleNode> samples)
{
    for (SampleNode sNode : samples)
    {
        // The list holds references, so updating the node is enough; no
        // set() back into the list is needed.
        sNode.in = sNode.in / 1000.0;
        sNode.hope = sNode.hope / 250.0;
    }
}
/**
 * Trains the network on {@code trainArray} for at most {@code trainTotal}
 * passes, stopping early once a pass's accumulated error is at or below
 * {@code outPrecision}.
 *
 * <p>Weights are adjusted after every sample. The earlier version called
 * {@code adjustWeight()} once per pass, which only ever applied the error
 * of the last sample in the list. The accumulated error of each pass is
 * written to the result file, followed by a summary when training ends.
 * Returns immediately when there is no training data.
 */
public void startTrain()
{
    if (trainArray == null || trainArray.size() == 0)
        return;
    System.out.println("训练开始");
    System.out.println("当前学习速率:" + learnRate);
    System.out.println("期望精度为:" + outPrecision);
    isQualify = false;
    int trainConunter = 0;
    while (trainConunter++ < trainTotal)
    {
        // Every pass measures its own error, independent of earlier passes
        // or of forward passes made elsewhere.
        totalError = 0.0;
        System.out.println("第" + trainConunter + "次训练开始:");
        for (SampleNode sNode : trainArray)
        {
            // inputArray and teacherArray are arrays for generality but hold
            // a single element in this experiment.
            inputArray[0] = sNode.in;
            teacherArray[0] = sNode.hope;
            forword();
            printResult();
            adjustWeight(); // learn from this sample before moving on
        }
        writeTrainLog(Double.toString(totalError) + "\n");
        if (totalError <= outPrecision)
        {
            isQualify = true; // required precision reached
            break;
        }
    }
    Date endTime = new Date();
    SimpleDateFormat sdf = new SimpleDateFormat("yyyy年MM月dd日HH时mm分ss秒");
    String endtimeStr = sdf.format(endTime);
    // Date.getTime() is in milliseconds, so the difference is too.
    long gap = endTime.getTime() - this.startTime.getTime();
    StringBuilder sb = new StringBuilder();
    sb.append("训练结束时间为:" + endtimeStr);
    sb.append("\n");
    sb.append("总的学习时间为:" + gap);
    sb.append("毫秒\n");
    sb.append("********************************************\n");
    writeTrainLog(sb.toString());
    if (!isQualify)
    {
        System.out.println("达到训练次数,训练结束!");
        writeTrainLog("训练次数:" + trainTotal + "次\n");
    } else
    {
        writeTrainLog("达到精度要求,学习完毕!\n");
        System.out.println("达到要求的精度,训练结束!");
    }
}

/**
 * Writes text to the result file, printing the stack trace if the write fails.
 *
 * @param text the text to append
 */
private void writeTrainLog(String text)
{
    try
    {
        bw.write(text);
    } catch (IOException e)
    {
        e.printStackTrace();
    }
}
/**
 * Feeds every test sample through the network, writing the input, actual
 * output and expected output to the result file and echoing them on the
 * console. Returns immediately when there are no test samples.
 */
public void startTest() {
    if (testArray == null || testArray.isEmpty()) {
        return;
    }
    for (int idx = 0; idx < testArray.size(); idx++) {
        SampleNode sample = testArray.get(idx);
        inputArray[0] = sample.in;
        teacherArray[0] = sample.hope;
        forword();
        String line = "输入测试数据: " + inputArray[0]
                + " 实际输出:" + outputArray[0]
                + " 期望输出:" + teacherArray[0]
                + "\n";
        try {
            bw.write(line);
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
        printResult();
    }
}
}
测试输出结果如下图:
程序运行一次的收敛图如下图:
<!--EndFragment-->