本文是对懒惰的gler–Java实现BP神经网络完成 Iris 数据分类:http://blog.csdn.net/u010858605/article/details/72898178 这篇博客的理解和对一些小问题的改进。原文的相关描述请点击上面的链接即可。
(1)数据选取
由于采用Iris 鸢尾花数据集,该数据集一共有150条记录,选取 Iris 数据集中的120条数据作为训练集(train.txt),剩余的30条数据作为测试集(test.txt)。
(2)java-code的理解
一共包含三个类: BPNN.java 、DataUtil.java 、Test.java
(3)代码的问题
3.1 在数据写入reslt.txt中的时候,判定的花型全是”Iris-setosa”,与事实不符合。
3.2 如果训练次数超过MaxTrain则没有给出判断的条件
(4)改进
4.1 – BPNN.java
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
class BPNN {
// private static int LAYER = 3; // 三层神经网络
private static int NodeNum = 10; // 每层的最多节点数
private static final int ADJUST = 5; // 隐层节点数调节常数
private static final int MaxTrain = 2000; // 最大训练次数
private static final double ACCU = 0.015; // 每次迭代允许的误差 iris:0.015
private double ETA_W = 0.5; // 权值学习效率0.5
private double ETA_T = 0.5; // 阈值学习效率0.5
private double accu;
// 附加动量项
//private static final double ETA_A = 0.3; // 动量常数0.1
//private double[][] in_hd_last; // 上一次的权值调整量
//private double[][] hd_out_last;
private int in_num; // 输入层节点数
private int hd_num; // 隐层节点数
private int out_num; // 输入出节点数
private ArrayList<ArrayList<Double>> list = new ArrayList<>(); // 输入输出数据
private double[][] in_hd_weight; // BP网络in-hidden突触权值
private double[][] hd_out_weight; // BP网络hidden_out突触权值
private double[] in_hd_th; // BP网络in-hidden阈值
private double[] hd_out_th; // BP网络hidden-out阈值
private double[][] out; // 每个神经元的值经S型函数转化后的输出值,输入层就为原值
private double[][] delta; // delta学习规则中的值
// 获得网络三层中神经元最多的数量
public int GetMaxNum() {
return Math.max(Math.max(in_num, hd_num), out_num);
}
// 设置权值学习率
public void SetEtaW() {
ETA_W = 0.5;
}
// 设置阈值学习率
public void SetEtaT() {
ETA_T = 0.5;
}
// BPNN训练
public void Train(int in_number, int out_number, ArrayList<ArrayList<Double>> arraylist) throws IOException {
list = arraylist;
in_num = in_number;
out_num = out_number;
GetNums(in_num, out_num); // 获取输入层、隐层、输出层的节点数
// SetEtaW(); // 设置学习率
// SetEtaT();
InitNetWork(); // 初始化网络的权值和阈值
int datanum = list.size(); // 训练数据的组数
int createsize = GetMaxNum(); // 比较每一层的节点数,取max
out = new double[3][createsize]; // 创建输出数组 out[3][7]
//训练次数为MaxTrain以内,如果训练次数超过MaxTrain则没有给出判断的条件
for (int iter = 0; iter < MaxTrain; iter++) {
for (int cnd = 0; cnd < datanum; cnd++) {
// 第一层输入节点赋值 out[0][4]
for (int i = 0; i < in_num; i++) {
//list.get(cnd).get(i) 取样本数据的第cnd组中第i个数据放入到out[0][i]中
out[0][i] = list.get(cnd).get(i); // 为输入层节点赋值,其输入与输出相同
}
Forward(); // 前向传播
Backward(cnd); // 误差反向传播
}
System.out.println("This is the " + (iter + 1) + " th trainning NetWork !");
accu = GetAccu();
System.out.println("All Samples Accuracy is " + accu);
if (accu < ACCU)
break;
}
}
// 获取输入层、隐层、输出层的节点数,in_number、out_number分别为输入层节点数和输出层节点数
public void GetNums(int in_number, int out_number) {
in_num = in_number;
out_num = out_number;
hd_num = (int) Math.sqrt(in_num + out_num) + ADJUST;
if (hd_num > NodeNum)
hd_num = NodeNum; // 隐层节点数不能大于最大节点数
}
// 初始化网络的权值和阈值
public void InitNetWork() {
// 初始化上一次权值量,范围为-0.5-0.5之间
//in_hd_last = new double[in_num][hd_num];
//hd_out_last = new double[hd_num][out_num];
in_hd_weight = new double[in_num][hd_num];
for (int i = 0; i < in_num; i++)
for (int j = 0; j < hd_num; j++) {
int flag = 1; // 符号标志位(-1或者1)
if ((new Random().nextInt(2)) == 1)
flag = 1;
else
flag = -1;
// New Random.nextDouble()的取值范围: [0,1.0)
in_hd_weight[i][j] = ( new Random().nextDouble() / 2 ) * flag; // 初始化in-hidden的权值
//in_hd_last[i][j] = 0;
}
hd_out_weight = new double[hd_num][out_num];
for (int i = 0; i < hd_num; i++)
for (int j = 0; j < out_num; j++) {
int flag = 1; // 符号标志位(-1或者1)
if ((new Random().nextInt(2)) == 1)
flag = 1;
else
flag = -1;
hd_out_weight[i][j] = (new Random().nextDouble() / 2) * flag; // 初始化hidden-out的权值
//hd_out_last[i][j] = 0;
}
// 阈值均初始化为0
// 输入层不处理数据,只接收数据,所以不设置阈值
in_hd_th = new double[hd_num];
for (int k = 0; k < hd_num; k++)
in_hd_th[k] = 0;
hd_out_th = new double[out_num];
for (int k = 0; k < out_num; k++)
hd_out_th[k] = 0;
}
/* // 计算单个样本的误差
public double GetError(int cnd) {
double ans = 0;
for (int i = 0; i < out_num; i++)
{
System.out.println(out[2][i]);
ans += 0.5 * (out[2][i] - list.get(cnd).get(in_num + i)) * (out[2][i] - list.get(cnd).get(in_num + i));
}
return ans;
}*/
// 计算所有样本的平均精度
public double GetAccu() {
double ans = 0;
int num = list.size();
for (int i = 0; i < num; i++) {
int m = in_num;
for (int j = 0; j < m; j++)
out[0][j] = list.get(i).get(j);
Forward();
int n = out_num;
for (int k = 0; k < n; k++){
//定义了输入与输出之间的平方误差
//System.out.println(list.get(i).get(in_num + k));
//System.out.println(out[2][k]);
ans += 0.5 * (list.get(i).get(in_num + k) - out[2][k]) * (list.get(i).get(in_num + k) - out[2][k]);
}
}
return ans / num;
}
// 前向传播
public void Forward() {
/**
* 计算隐层节点的输出值
* v = 求和( 每个输入层数据 * 每个隐层的权重 ) + 对应 隐层 的阈值
* in_hd_weight[4][7] out[0][4] in_hd_th[7]
*/
for (int j = 0; j < hd_num; j++) {
double v = 0;
for (int i = 0; i < in_num; i++)
v += in_hd_weight[i][j] * out[0][i];
v += in_hd_th[j];
out[1][j] = Sigmoid(v);
}
/**
* 计算输出层节点的输出值
* v = 求和( 每个隐层输出数据 * 每个输出层的权重 ) + 对应 输出层 的阈值
* hd_out_weight[7][3] out[1][3] hd_out_th[3]
*/
for (int j = 0; j < out_num; j++) {
double v = 0;
for (int i = 0; i < hd_num; i++)
v += hd_out_weight[i][j] * out[1][i];
v += hd_out_th[j];
out[2][j] = Sigmoid(v);
}
}
// 误差反向传播 = 计算权值调整量 + 更新BP神经网络的权值和阈值
public void Backward(int cnd) {
CalcDelta(cnd); // 计算权值调整量
UpdateNetWork(); // 更新BP神经网络的权值和阈值
}
// 计算delta调整量
public void CalcDelta(int cnd) {
int createsize = GetMaxNum(); // 比较创建数组
delta = new double[3][createsize];
// 计算输出层的delta值 cnd ( 0 - 119 )
for (int i = 0; i < out_num; i++) {
//System.out.println(list.size());
delta[2][i] = (list.get(cnd).get(in_num + i) - out[2][i]) * SigmoidDerivative(out[2][i]);
}
// 计算隐层的delta值
for (int i = 0; i < hd_num; i++) {
double t = 0;
for (int j = 0; j < out_num; j++)
t += hd_out_weight[i][j] * delta[2][j];
delta[1][i] = t * SigmoidDerivative(out[1][i]);
}
}
// 更新BP神经网络的权值和阈值
public void UpdateNetWork() {
// 隐含层和输出层之间权值和阀值调整
for (int i = 0; i < hd_num; i++) {
for (int j = 0; j < out_num; j++) {
hd_out_weight[i][j] += ETA_W * delta[2][j] * out[1][i]; // 未加权值动量项
/* 动量项
* hd_out_weight[i][j] += (ETA_A * hd_out_last[i][j] + ETA_W
* delta[2][j] * out[1][i]); hd_out_last[i][j] = ETA_A *
* hd_out_last[i][j] + ETA_W delta[2][j] * out[1][i];
*/
}
}
for (int i = 0; i < out_num; i++)
hd_out_th[i] += ETA_T * delta[2][i];
// 输入层和隐含层之间权值和阀值调整
for (int i = 0; i < in_num; i++) {
for (int j = 0; j < hd_num; j++) {
in_hd_weight[i][j] += ETA_W * delta[1][j] * out[0][i]; // 未加权值动量项
/* 动量项
* in_hd_weight[i][j] += (ETA_A * in_hd_last[i][j] + ETA_W
* delta[1][j] * out[0][i]); in_hd_last[i][j] = ETA_A *
* in_hd_last[i][j] + ETA_W delta[1][j] * out[0][i];
*/
}
}
for (int i = 0; i < hd_num; i++)
in_hd_th[i] += ETA_T * delta[1][i];
}
// 符号函数sign
public int Sign(double x) {
if (x > 0)
return 1;
else if (x < 0)
return -1;
else
return 0;
}
// 返回最大值
public double Maximum(double x, double y) {
if (x >= y)
return x;
else
return y;
}
// 返回最小值
public double Minimum(double x, double y) {
if (x <= y)
return x;
else
return y;
}
// log-sigmoid函数
public double Sigmoid(double x) {
return (double) (1 / (1 + Math.exp(-x)));
}
// log-sigmoid函数的倒数
public double SigmoidDerivative(double y) {
return (double) (y * (1 - y));
}
/* // tan-sigmoid函数
public double TSigmoid(double x) {
return (double) ((1 - Math.exp(-x)) / (1 + Math.exp(-x)));
}
// tan-sigmoid函数的倒数
public double TSigmoidDerivative(double y) {
return (double) (1 - (y * y));
}*/
// 分类预测函数
public ArrayList<ArrayList<Double>> ForeCast(
ArrayList<ArrayList<Double>> arraylist) {
ArrayList<ArrayList<Double>> alloutlist = new ArrayList<>();
ArrayList<Double> outlist = new ArrayList<Double>();
int datanum = arraylist.size();
for (int cnd = 0; cnd < datanum; cnd++) {
for (int i = 0; i < in_num; i++)
out[0][i] = arraylist.get(cnd).get(i); // 为输入节点赋值
Forward();
for (int i = 0; i < out_num; i++) {
if (out[2][i] > 0 && out[2][i] < 0.5)
out[2][i] = 0;
else if (out[2][i] > 0.5 && out[2][i] < 1) {
out[2][i] = 1;
}
outlist.add(out[2][i]);
//System.out.println( out[2][i] );
}
alloutlist.add(outlist);
outlist = new ArrayList<Double>();
outlist.clear();
}
return alloutlist;
}
}
4.2 – DataUtil.java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
class DataUtil {
private ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据
private ArrayList<String> outlist = new ArrayList<String>(); // 存放输出数据,索引对应每个everylist的输出
private ArrayList<String> checklist = new ArrayList<String>(); //存放测试集的真实输出字符串
private int in_num = 0;
private int out_num = 0; // 输入输出数据的个数
private int type_num = 0; // 输出的类型数量
private double[][] nom_data; //归一化输入数据中的最大值和最小值
private int in_data_num = 0; //提前获得输入数据的个数
// 获取输出类型的个数
public int GetTypeNum() {
return type_num;
}
// 设置输出类型的个数
public void SetTypeNum(int type_num) {
this.type_num = type_num;
}
// 获取输入数据的个数
public int GetInNum() {
return in_num;
}
// 获取输出数据的个数
public int GetOutNum() {
return out_num;
}
// 获取所有数据的数组
public ArrayList<ArrayList<Double>> GetList() {
return alllist;
}
// 获取输出为字符串形式的数据
public ArrayList<String> GetOutList() {
return outlist;
}
// 获取输出为字符串形式的数据
public ArrayList<String> GetCheckList() {
return checklist;
}
//返回归一化数据所需最大最小值
public double[][] GetMaxMin(){
return nom_data;
}
// 读取文件初始化数据
public void ReadFile( String filepath, String sep, int flag ) throws Exception {
ArrayList<Double> everylist = new ArrayList<Double>(); // 存放每一组输入输出数据
int readflag = flag; // flag=0,train;flag=1,test
String encoding = "GBK"; //编码格式"GBK"
File file = new File(filepath);
if (file.isFile() && file.exists()) { // 判断文件是否存在
InputStreamReader read = new InputStreamReader(new FileInputStream( file ), encoding);// 考虑到编码格式
BufferedReader bufferedReader = new BufferedReader(read);
String lineTxt = null;
while ((lineTxt = bufferedReader.readLine()) != null) {
int in_number = 0;
//将每一行的数据按','截取字符串
String splits[] = lineTxt.split(sep);
if (readflag == 0) {
for (int i = 0; i < splits.length; i++)
try {
//对数据进行归一化处理
everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
in_number++;
} catch (Exception e) {
//outlist:存放输出数据的类型
if (!outlist.contains(splits[i]))
outlist.add(splits[i]); // 存放字符串形式的输出数据
//初始化[-,-,-,-,0.0,0.0,0.0]
for (int k = 0; k < type_num; k++) {
everylist.add(0.0);
}
// 0-3:四个属性 4-6:输出节点处理,进行one-hot编程
// outlist.indexOf(splits[i]):获取第几位的不为空
// everylist 存放着[ 0 - 6 ] 位
everylist.set(in_number + outlist.indexOf(splits[i]),1.0);
}
} else if (readflag == 1) {
for (int i = 0; i < splits.length; i++)
try {
everylist.add(Normalize(Double.valueOf(splits[i]),nom_data[i][0],nom_data[i][1]));
in_number++;
} catch (Exception e) {
checklist.add(splits[i]); // 存放字符串形式的输出数据
}
}
alllist.add(everylist); // 存放所有数据
in_num = in_number;
out_num = type_num;
everylist = new ArrayList<Double>();
everylist.clear();
}
bufferedReader.close();
}
}
//向文件写入分类结果
public void WriteFile(String filepath, ArrayList<ArrayList<Double>> list, int in_number, ArrayList<String> resultlist) throws IOException{
File file = new File(filepath);
FileWriter fw = null;
BufferedWriter writer = null;
try {
fw = new FileWriter(file);
writer = new BufferedWriter(fw);
for(int i=0;i<list.size();i++){
for(int j=0;j<in_number;j++){
writer.write(list.get(i).get(j)+",");
}
writer.write(resultlist.get(i));
writer.newLine();
}
writer.flush();
} catch (IOException e) {
e.printStackTrace();
}finally{
writer.close();
fw.close();
}
}
//学习样本归一化,找到输入样本数据的最大值和最小值
public void NormalizeData(String filepath) throws IOException{
//提前获得输入数据的个数
GetBeforIn(filepath);
int flag=1;
//nom_data存放输入节点的max和min in_data_num:4
nom_data = new double[in_data_num][2];
String encoding = "GBK";
File file = new File(filepath);
if ( file.isFile() && file.exists() ) { // 判断文件是否存在
InputStreamReader read = new InputStreamReader( new FileInputStream(file), encoding );// 考虑到编码格式
BufferedReader bufferedReader = new BufferedReader(read);
String lineTxt = null;
while ((lineTxt = bufferedReader.readLine()) != null) {
String splits[] = lineTxt.split(","); // 按','截取字符串
for (int i = 0; i < splits.length-1; i++){
if(flag==1){
nom_data[i][0]=Double.valueOf(splits[i]);
nom_data[i][1]=Double.valueOf(splits[i]);
}
else{
if(Double.valueOf(splits[i])>nom_data[i][0])
nom_data[i][0]=Double.valueOf(splits[i]);
if(Double.valueOf(splits[i])<nom_data[i][1])
nom_data[i][1]=Double.valueOf(splits[i]);
}
}
flag=0;
}
bufferedReader.close();
}
}
//归一化前获得输入数据的个数
public void GetBeforIn(String filepath) throws IOException{
String encoding = "GBK";
File file = new File(filepath);
if (file.isFile() && file.exists()) { // 判断文件是否存在
InputStreamReader read = new InputStreamReader(new FileInputStream(
file), encoding);// 考虑到编码格式
//提前获得输入数据的个数
BufferedReader beforeReader = new BufferedReader(read);
String beforetext = beforeReader.readLine();
String splits[] = beforetext.split(",");
in_data_num = splits.length-1;
beforeReader.close();
}
}
//归一化公式 -- 用于读取文件中
public double Normalize(double x, double max, double min){
double y = 0.1+0.8*(x-min)/(max-min);
return y;
}
}
4.3 – Test.java
import java.util.ArrayList;
public class Test {
public static void main(String args[]) throws Exception {
//alllist = 4 + 3 即输入和输出
ArrayList<ArrayList<Double>> alllist = new ArrayList<ArrayList<Double>>(); // 存放所有数据
ArrayList<String> outlist = new ArrayList<String>(); // 存放分类的字符串
int in_num = 0, out_num = 0; // 输入输出数据的个数
DataUtil dataUtil = new DataUtil(); // 初始化数据
dataUtil.NormalizeData("F:\\实训\\code\\BPNN_three\\data\\train.txt"); //对数据进行归一化处理
dataUtil.SetTypeNum(3); // 设置输出类型的数量
dataUtil.ReadFile("F:\\实训\\code\\BPNN_three\\data\\train.txt", ",", 0);
in_num = dataUtil.GetInNum(); // 获得输入数据的个数
out_num = dataUtil.GetOutNum(); // 获得输出数据的个数(个数代表类型个数)
alllist = dataUtil.GetList(); // 获得初始化后的数据
outlist = dataUtil.GetOutList();
//System.out.println(outlist);
System.out.print("分类的类型:");
for(int i =0 ;i<outlist.size();i++)
System.out.print(outlist.get(i)+" ");
System.out.println();
System.out.println("训练集的数量:"+alllist.size());
BPNN bpnn = new BPNN();
// 训练
System.out.println("Train Start!");
System.out.println(".............");
bpnn.Train(in_num, out_num, alllist);
System.out.println("Train End!");
// 测试
DataUtil testUtil = new DataUtil();
testUtil.NormalizeData("F:\\实训\\code\\BPNN_three\\data\\test.txt");
testUtil.SetTypeNum(3); // 设置输出类型的数量
testUtil.ReadFile("F:\\实训\\code\\BPNN_three\\data\\test.txt", ",", 1);
ArrayList<ArrayList<Double>> testList = new ArrayList<ArrayList<Double>>();
ArrayList<ArrayList<Double>> resultList = new ArrayList<ArrayList<Double>>();
ArrayList<String> normallist = new ArrayList<String>(); // 存放测试集标准的输出字符串
ArrayList<String> resultlist = new ArrayList<String>(); // 存放测试集计算后的输出字符串
int right = 0; // 分类正确的数量
int type_num = 0; // 类型的数量
int all_num = 0; //测试集的数量
type_num = outlist.size();
testList = testUtil.GetList(); // 获取测试数据
normallist = testUtil.GetCheckList();
//int errorcount = 0; // 分类错误的数量
resultList = bpnn.ForeCast(testList); // 测试
all_num = resultList.size();
//resultList:[-,-,-] normallist:[-] outlist:[-,-,-]
//System.out.println(resultList);
//System.out.println(normallist);
//System.out.println(outlist);
//临时存放结果
ArrayList<String> Temp = new ArrayList<String>();
//resultList=[30][3] 这里的输出有问题???解决方式:增加一个临时存放结果的数组
for (int i = 0; i < resultList.size(); i++) {
String checkString = "unknow";
for (int j = 0; j < type_num; j++) {
//System.out.println(resultList.get(i).get(j));
if( resultList.get(i).get(j) == 1.0 ){
//System.out.println(outlist.get(j));
checkString = outlist.get(j);
Temp.add(checkString);
}
else{
resultlist.add(checkString);
}
}
/* if(checkString.equals("unknow"))
errorcount++;*/
//normallist.get(i)为实际的判定值
if(checkString.equals(normallist.get(i)))
right++;
}
//System.out.println(Temp);
testUtil.WriteFile("F:\\实训\\code\\BPNN_three\\data\\result.txt",testList,in_num,Temp);
System.out.println("测试集的数量:"+ all_num );
System.out.println("分类正确的数量:"+ right );
//System.out.println("分类正确的数量:"+(new Double(right)).intValue());
System.out.println("算法的分类正确率为:"+ (new Double( (double) right/all_num )));
System.out.println("分类结果存储在:F:\\实训\\code\\BPNN_three\\data\\result.txt");
//bpnn.GetError(1);
}
}
(5)运行截图
…………………………………………….一共30组
(6)参考资料
1.原作者博客:http://blog.csdn.net/u010858605/article/details/72898178
2.数据集下载:http://archive.ics.uci.edu/ml/index.php
3.归一化处理:http://www.cnblogs.com/chaosimple/p/3227271.html
4.one-hot编程:http://www.cnblogs.com/daguankele/p/6595470.html
5.delta学习:http://blog.csdn.net/u012562273/article/details/56297648
6.机器学习之BP神经网络(三) : https://zhuanlan.zhihu.com/p/28993795