在前面的几篇文章中,都对手写识别进行了一些讲解,这里主要是介绍一下通过另外一种方法来进行识别---------朴素贝叶斯。。自己也是处于机器学习路上的一名新手,如果有什么讲解不对的话,欢迎大家进行交流,可以把建议写到下面。。。。
好了,不多说,进入正题。。朴素贝叶斯,我相信,搞机器学习的人都不会陌生,关于它的一些基本概念我就不说了,如果还有什么不明白的地方,可以去百度查查它的理论知识。我主要就是对于手写识别来进行针对性的讲解。
就把朴素贝叶斯中,最为关键的公式贴出来:
一:关于训练集数据
这部分,我在前面的文章中,进行了讲解,而且数据集我也分享到了百度云,如果有需要的可以翻看一下前面的那篇神经网络的文章进行下载。
二:朴素贝叶斯的实践
这里讲解一下,大概的步骤吧。其实了解朴素贝叶斯的话,应该很好理解如何进行实施的,毕竟这算法的优点就是通过概率来预测的这么一种简单的方法。 (就把自己做课程报告中PPT写的东西贴出来)
我想,如果了解朴素贝叶斯的基本理念再加上上面的一些提示,那么应该就知道如何进行实施了。。下面就是代码(Java语言)::
1:读取训练集数据(.csv后缀的文件)
package beiyesifenleiqi;
/*
* 读取后缀为csv的excell文件
*
*/
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
public class CSVFileUtil {
private String fileName = null;
private BufferedReader br = null;
private List<String> list = new ArrayList<String>();
public CSVFileUtil() {
}
public CSVFileUtil(String fileName) throws Exception {
this.fileName = fileName;
br = new BufferedReader(new FileReader(fileName));
String stemp;
while ((stemp = br.readLine()) != null) {
list.add(stemp);
}
}
public List getList() {
return list;
}
/**
* 获取行数
* @return
*/
public int getRowNum() {
return list.size();
}
/**
* 获取列数
* @return
*/
public int getColNum() {
if (!list.toString().equals("[]")) {
if (list.get(0).toString().contains(",")) {// csv为逗号分隔文件
return list.get(0).toString().split(",").length;
} else if (list.get(0).toString().trim().length() != 0) {
return 1;
} else {
return 0;
}
} else {
return 0;
}
}
/**
* 获取制定行
* @param index
* @return
*/
public String getRow(int index) {
if (this.list.size() != 0) {
return (String) list.get(index);
} else {
return null;
}
}
/**
* 获取指定列
* @param index
* @return
*/
public String getCol(int index) {
if (this.getColNum() == 0) {
return null;
}
StringBuffer sb = new StringBuffer();
String tmp = null;
int colnum = this.getColNum();
if (colnum > 1) {
for (Iterator it = list.iterator(); it.hasNext();) {
tmp = it.next().toString();
sb = sb.append(tmp.split(",")[index] + ",");
}
} else {
for (Iterator it = list.iterator(); it.hasNext();) {
tmp = it.next().toString();
sb = sb.append(tmp + ",");
}
}
String str = new String(sb.toString());
str = str.substring(0, str.length() - 1);
return str;
}
/**
* 获取某个单元格
* @param row
* @param col
* @return
*/
public String getString(int row, int col) {
String temp = null;
int colnum = this.getColNum();
if (colnum > 1) {
temp = list.get(row).toString().split(",")[col];
} else if(colnum == 1){
temp = list.get(row).toString();
} else {
temp = null;
}
return temp;
}
public void CsvClose()throws Exception{
this.br.close();
}
}
2: 构建朴素贝叶斯
package beiyesifenleiqi;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
/*
* 贝叶斯分类器
*/
public class BeiYeSIFenLeiQi {
public int[] splicStyleResult; //分类的结果10种,即0-9
public int[][] numberEveryLie; //每一个数字对应的每一列的概率
public double[][] gailvEveryHang; //每个数字的每一行中的概率
public double[] gailvStyleResult; //训练集中每个数字出现的概率
public int totalResultNumber; //记录总的输入结果的个数就是10,便于计算概率而已
public int totalHangNumber; //记录总的行数,其实就是784
private int xunlianjigeshu;
public BeiYeSIFenLeiQi(int styleResult,int totalLie){
splicStyleResult=new int[styleResult]; //初始化所需要进行分类的个数
numberEveryLie=new int[styleResult][totalLie]; //也就是10和784
gailvEveryHang=new double[styleResult][totalLie]; //也就是10和784
gailvStyleResult=new double[styleResult]; //每个数字的概率
totalResultNumber=styleResult; //得到结果个数
totalHangNumber=totalLie; //得到列数
}
/*
* 设置训练集的总个数
*/
public void setXunLianNumber(int geshu){
xunlianjigeshu=geshu;
}
/*
* 计算在训练集中,每一个结果标签对应的个数
*/
public void addeveryResultNumber(int numberresult, int currentnumber){
splicStyleResult[currentnumber]+=numberresult; //对应的标签个数加上线程处理后的个数
}
/*
* 找到对应的结果数字,并且将相应的列为1值的索引数加1,主要用来后面算概率
*/
public void addEveryLieNumber(int numberresult,int liesuoyin,int number){
numberEveryLie[numberresult][liesuoyin]+=number; //将像素数值为1的对应的数字的索引个数加1,(784中)
}
/*
* 更新处理完概率后每个线程计算出来的概率总和
*
*/
public void updataThreadComputeGaiLv(double[] gailvResult , double[][] gaileveryElement){
//更新数字的概率
for(int i=0 ; i<totalResultNumber ; i++){
gailvStyleResult[i]+=gailvResult[i];
for(int j=0 ; j<totalHangNumber ; j++){
gailvEveryHang[i][j]+=gaileveryElement[i][j];
}
}
}
/*
* 打印概率结果
*/
public void printfResult(){
for(int i=0;i<10;i++){
// System.out.println(gailvStyleResult[i]);//打印每个数字的概率
// for(int j=0;j<784;j++){
// System.out.print(gailvEveryHang[i][j]+" ");//打印每个数字的每个像素的概率
// }
System.out.println();
}
}
/*
* 计算测试集的结果
*/
public int computeYuCeResult(double[] binary) {
double[] everyGailv=new double[10];
//计算每个数字的可能性概率
for(int currennumber=0;currennumber<10;currennumber++){
everyGailv[currennumber]=gailvStyleResult[currennumber]; //得到在训练集中该数字出现的概率
for(int suoyin=0;suoyin<784;suoyin++){
if(binary[suoyin]==0){ //表示该像素上没有可能性
everyGailv[currennumber]=everyGailv[currennumber]*(1-gailvEveryHang[currennumber][suoyin]);//贝叶斯分类的概率计算,因为该位置为0,则表示与训练集汇总的可能性较大
}
else if(binary[suoyin]==1){ //表示该位置出现了,则按之后算好的概率进行计算
everyGailv[currennumber]=everyGailv[currennumber]*gailvEveryHang[currennumber][suoyin];
}
}
}
//比较存储的10个数字中,概率最大的是哪个,则表示最有可能的预测就是哪个数字
double sumGailv=0; //总的概率
for(int i=0;i<10;i++){
sumGailv=sumGailv+everyGailv[i];
}
for(int j=0;j<10;j++){
everyGailv[j]=everyGailv[j]/sumGailv; //得到权重的百分比
}
double maxGailv=everyGailv[0];
int maxSuoyin=0;
for(int max=1;max<10;max++){
if(maxGailv<everyGailv[max]){
maxGailv=everyGailv[max];
maxSuoyin=max;
}
}
return maxSuoyin; //返回预测的数字
}
/*
* 加载之前已经训练过的数据
*/
public void loadPreviousData(File writePath) throws Exception {
FileInputStream in=new FileInputStream(writePath);
InputStreamReader isr=new InputStreamReader(in, "UTF-8"); //防止乱码
BufferedReader br = new BufferedReader(isr);
String currenline ="";
int suoyin=0;
try {
while((currenline=br.readLine())!=null){
String[] fengeeverynumber=currenline.split(","); //得到每一行的每小格的数据
int totallength=fengeeverynumber.length;
gailvStyleResult[suoyin] = Double.valueOf(fengeeverynumber[0]);
for(int i=1;i<totallength-1;i++){
gailvEveryHang[suoyin][i-1]=Double.valueOf(fengeeverynumber[i]);
}
suoyin++;
}
} catch (IOException e) {
e.printStackTrace();
}
finally{
br.close();
}
}
/*
* 把每个概率写入到Txt文件中,方便后面读
*/
public void writeEveryGailv(File writePath) throws IOException {
int lengthdata = gailvStyleResult.length; //得到数字概率的个数(其实就是10个)
FileWriter fw= new FileWriter(writePath);
BufferedWriter bw= new BufferedWriter(fw);
for(int resultsuoyin=0;resultsuoyin<lengthdata;resultsuoyin++){
bw.write(gailvStyleResult[resultsuoyin]+",");
for(int i=0;i<784;i++){
bw.write(gailvEveryHang[resultsuoyin][i]+","); //写入数据
}
bw.write("\t\n"); //加个换行(一个数据一行)
}
bw.close();
}
}
3:线程类(因为数据太多了,就用多线程进行了读取,这样来减少读取的时间)
package beiyesifenleiqi;
import java.util.concurrent.CountDownLatch;
import shenjingwangluo2.CSVFileUtil;
import beiyesifenleiqi.Text;
public class startDealThread implements Runnable{
int startindex;
int overindex;
int trainResultNumber;
int xunlianjigeshu;
int totalHangNumber;
CSVFileUtil resultData;
CSVFileUtil trainData;
int[] getnumber;
int[][] everyNumberGeshu;
double[] gailvResultNumber;
double[][] gailvEveryHangLie;
CountDownLatch countDownLatch;
Text manythread;
public startDealThread(Text manythread, CountDownLatch countDownLatch,
int suoyin,int oversuoyin,CSVFileUtil util,CSVFileUtil util2,int totalResultNumber,int totalLieNumber) {
startindex=suoyin;
this.overindex=oversuoyin;
resultData=util;
trainData=util2;
totalHangNumber=totalLieNumber; //总的元素的个数(784)
xunlianjigeshu=oversuoyin-suoyin; //每个线程训练的个数
this.countDownLatch=countDownLatch;
this.manythread=manythread;
trainResultNumber=totalResultNumber;
getnumber = new int[totalResultNumber]; //存储每个数字的个数
everyNumberGeshu=new int[totalResultNumber][totalLieNumber]; //存储每个元素的个数
gailvResultNumber = new double[totalResultNumber]; //存储每个数字的概率
gailvEveryHangLie = new double[totalResultNumber][totalLieNumber]; //存储每个元素的概率
}
@Override
public void run() {
compute(startindex); //计算每个数字的次数
computeGaiLVEveryHang(); //计算每个元素的概率
manythread.updataAllData(gailvResultNumber,gailvEveryHangLie); //计算概率完成之后更新所有线程计算出现的概率
countDownLatch.countDown(); //表示该线程已经进行执行完成
System.out.println("线程我已经完成计算工作!!");
}
/*
* 计算数据
*/
private void compute(int currentsuoyin) {
int resultNumber=0;
int suoyinlie=0;
int value=0;
for(int i=currentsuoyin;i<overindex;i++){
resultNumber=Integer.parseInt(resultData.getString(i, 0));
getnumber[resultNumber]=getnumber[resultNumber]+1;
while(suoyinlie<784){
value=Integer.parseInt(trainData.getString(i, suoyinlie));
if(value>=128){
addEveryLieNumber(resultNumber, suoyinlie); //主要是为了让数据中只有0和1这样的灰度数据方便计算
} //而且对于存在1的时候,才进行存储,也就是表示实际有像素点被画
suoyinlie++;
}
suoyinlie=0; //处理一个后,记得还原
}
}
/*
* 计算每一行中的像素为1的个数
*/
private void addEveryLieNumber(int resultNumber, int suoyinlie) {
everyNumberGeshu[resultNumber][suoyinlie]+=1; //将数值为1的索引个数加1,(784中)
}
/*
* 计算每一个结果每一行中的概率
*/
public void computeGaiLVEveryHang(){
double tempResult=0; //每个行的中间结果(784个)
double tempResultgeshu=0; //得到每个数字出现的个数
double styleResult=0; //保存每个数字的概率中间变量而已
double everyResult=0; //保存每个元素的个数,中间变量而已
for(int i=0;i<trainResultNumber;i++){
tempResultgeshu=getnumber[i]; //得到训练集中,对应数字的个数
if(tempResultgeshu==0){ //防止一个都没出现的情况,为了效果更加的平滑
tempResultgeshu=1;
}
styleResult=tempResultgeshu/xunlianjigeshu; //得到每个数字在训练集中出现的概率
gailvResultNumber[i]=styleResult;
for(int j=0;j<totalHangNumber;j++){
if(everyNumberGeshu[i][j]==0){ //表示一个样本都没出现,则加1,防止出现平滑处理
everyNumberGeshu[i][j]=1;
}
everyResult=everyNumberGeshu[i][j];
tempResult=everyResult/tempResultgeshu; //得到对应的概率
gailvEveryHangLie[i][j]=tempResult; //得到每个元素的概率
}
}
}
}
4: 训练数据及其测试数据的主类
package beiyesifenleiqi;
import java.io.File;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Date;
import java.util.concurrent.CountDownLatch;
import javax.xml.crypto.Data;
import shenjingwangluo2.CSVFileUtil;
/*
* 进行测试
*
*/
public class Text {
private static Text text;
private static BeiYeSIFenLeiQi beiyesi;
private static File writePath;
public static void main(String[] args) throws Exception {
Date data=new Date();
SimpleDateFormat si=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
System.out.println("开始训练时间:"+si.format(data));
text = new Text();
beiyesi = new BeiYeSIFenLeiQi(10, 784);
writePath = new File("D:/xunlianresult");
if(writePath.exists()){ //如果指定的文件已经存在,表示之前已经写入了数据
beiyesi.loadPreviousData(writePath);
}
else{ //没有训练结果,则进行训练
writePath.createNewFile();
dataInit(beiyesi); //训练数据的处理
System.out.println("训练数据并存储数据成功!!!");
}
computeTextData(beiyesi); //进行测试
}
/*
* 进行测试
*/
private static void computeTextData(BeiYeSIFenLeiQi beiyesi) throws Exception {
//输入测试数据
CSVFileUtil util3 = new CSVFileUtil("D:\\textdata.csv");
int textthang=util3.getRowNum(); //得到测试数据行数
CSVFileUtil util4 = new CSVFileUtil("D:\\textresult.csv");
//二值化进行处理,并且进行预测
int getTextNumber=0; //保存测试数据当前预测的值
int yuceresult=0; //预测的结果
int accurencynumber=0; //预测正确的个数
double[] binary=new double[784]; //二值化的信息
for(int i=0;i<textthang;i++){
getTextNumber=Integer.parseInt(util4.getString(i, 0)); //得到测试数据当前预测的值
int currentsuoyin=0;
int value=0;
while(currentsuoyin<784){
value=Integer.parseInt(util3.getString(i, currentsuoyin));
if(value>128){ //因为二值化后大于128的就为1
binary[currentsuoyin]=1;
}
currentsuoyin++;
}
currentsuoyin=0;
//进行预测结果
yuceresult=beiyesi.computeYuCeResult(binary); //得到预测的结果
System.out.print("经过预测的值为:"+yuceresult); //打印预测的结果
if(yuceresult==getTextNumber){ //比较预测和真实结果,是否相同
accurencynumber++;
System.out.println("(正确)");
}
else{
System.out.println("(错误),实际的数字为:"+getTextNumber);
}
Arrays.fill(binary, 0); //记得每次要把数组清零否则会影响后面的内容
}
//打印准确度
double result=(Double.valueOf(accurencynumber)/Double.valueOf(textthang))*100;
System.out.println("精确度为:"+result+"%");
Date data=new Date();
SimpleDateFormat si=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
System.out.println("测试结束时间:"+si.format(data));
}
/*
* 训练数据的初始化处理
*/
private static void dataInit(BeiYeSIFenLeiQi beiyesi) throws Exception {
//得到训练集的结果标签
CSVFileUtil util = new CSVFileUtil("D:\\trainresult.csv");
int resulthang=util.getRowNum(); //得到训练结果行数
int resultlie=util.getColNum(); //得到训练结果列数
//得到训练数据的结果
CSVFileUtil util2 = new CSVFileUtil("D:\\traindata.csv");
int inputhang=util2.getRowNum(); //得到训练结果行数
int inputlie=util2.getColNum(); //得到训练结果列数
beiyesi.setXunLianNumber(resulthang);
readData(util,util2); //初始化数据(用多线程进行处理)
//下面这些是最开始写的时候没用多线程进行处理的方法,也可以,就是训练太慢了,自己又改进了s
// int resultNumber=0;
// int suoyinlie = 0;
// int value=0;
// for(int i=0;i<resulthang;i++){ //进行需要处理数据的个数的统一
// resultNumber=Integer.parseInt(util.getString(i, 0)); //得到结果标签的值
// beiyesi.addeveryResultNumber(resultNumber); //对应的个数+1
//
// while(suoyinlie<784){
// value=Integer.parseInt(util2.getString(i, suoyinlie));
// if(value>=255/2){
// beiyesi.addEveryLieNumber(resultNumber, suoyinlie); //主要是为了让数据中只有0和1这样的灰度数据方便计算
// } //而且对于存在1的时候,才进行存储,也就是表示实际有像素点被画
// suoyinlie++;
// }
// suoyinlie=0; //处理一个后,记得还原
// }
//进行每个数字中对应1位置占总数字个数的概率的计算
// beiyesi.computeGaiLVEveryHang();
// beiyesi.printfResult(); //打印概率的结果
}
/*
* 初始化训练数据的内容
*/
private static void readData(CSVFileUtil util, CSVFileUtil util2) throws IOException {
CountDownLatch countDownLatch=new CountDownLatch(10); //开10个线程进行读取数据处理
for(int i=0;i<10;i++){ //开启线程进行统计需要的数量
Thread start=new Thread(new startDealThread(text,countDownLatch,i*250,(i+1)*250,util,util2,10,784));
start.start();
}
try {
countDownLatch.await(); //等待所有的子线程全部执行完成,才执行后面的任务
//beiyesi.computeGaiLVEveryHang(); //所有数字的个数都记录好之后,进行计算概率
beiyesi.writeEveryGailv(writePath); //把训练好的结果存放到TXT文件,方便下次直接读取
} catch (InterruptedException e) {
e.printStackTrace();
}
}
/*
* 更新线程计算完后每个数字和元素的概率
*/
public synchronized void updataAllData(double[] gailvResult, double[][] gaileveryElement) {
beiyesi.updataThreadComputeGaiLv(gailvResult, gaileveryElement); //更新概率
// int currentnumber=0;
// for(int i=0;i<10;i++){
// beiyesi.addeveryResultNumber(getnumber[i],i);
// for(int liesuoyin=0;liesuoyin<784;liesuoyin++){
// currentnumber=everyNumberGeshu[i][liesuoyin];
// beiyesi.addEveryLieNumber(i, liesuoyin,currentnumber);
// }
// }
}
}
上面的代码没有进行太多的优化,所以可能存在一些多余的部分,而且用了线程也是为了方便读取数据而已。。
通过上面的代码的话,最后达到的准确度有94%左右,所以这个基本还行吧。至少能进行有效的识别了。。。。。。。。。。。…………
注意:: 下面还说说,自己在做这个实验过程中,遇到的一些关于朴素贝叶斯进行分类的问题吧!(进行贴图总结了)!!!
反应出的问题(也就是缺点):
好了,这个就讲解到这里了。。。共同进步,慢慢的学习!!!!!!!!!!!!!!