《机器学习实战》朴素贝叶斯的 Java 实现

package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * Holds a feature matrix as a list of rows, each row being a list of strings.
 *
 * @author haolidong
 */
public class Matrix {
	/** Rows of the matrix; starts out empty. */
	public ArrayList<ArrayList<String>> data = new ArrayList<ArrayList<String>>();

	/** Creates a matrix with no rows. */
	public Matrix() {
	}
}
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * A {@link Matrix} of tokenised documents together with one label string per document.
 *
 * @author haolidong
 */
public class CreateDataSet extends Matrix {
	/** Labels, stored in the same order as the rows appended to {@code data}. */
	public ArrayList<String> labels = new ArrayList<String>();

	/** Creates a data set with no rows and no labels. */
	public CreateDataSet() {
		super();
	}

	/**
	 * Appends six hard-coded sample posts to {@code data} and then their six labels
	 * ("0", "1", "0", "1", "0", "1") to {@code labels}.
	 *
	 * @author haolidong
	 */
	public void initTest() {
		String[][] posts = {
			{ "my", "dog", "has", "flea", "problems", "help", "please" },
			{ "maybe", "not", "take", "him", "to", "dog", "park", "stupid" },
			{ "my", "dalmation", "is", "so", "cute", "I", "love", "him" },
			{ "stop", "posting", "stupid", "worthless", "garbage" },
			{ "mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him" },
			{ "quit", "buying", "worthless", "dog", "food", "stupid" }
		};
		String[] tags = { "0", "1", "0", "1", "0", "1" };

		for (String[] post : posts) {
			ArrayList<String> words = new ArrayList<String>();
			for (String word : post) {
				words.add(word);
			}
			data.add(words);
		}
		for (String tag : tags) {
			labels.add(tag);
		}
	}
}
package com.haolidong.Bayes;

import java.util.ArrayList;
/**
 * Result holder for naive Bayes training.
 *
 * <p>{@code p0Vect} and {@code p1Vect} are the per-word vectors for class 0 and class 1,
 * and {@code pAbusive} is the proportion of samples labelled 1.
 *
 * @author haolidong
 */
public class TrainNB0DataSet {
	/** Per-word vector for class 0; starts out empty. */
	public ArrayList<Double> p0Vect = new ArrayList<Double>();
	/** Per-word vector for class 1; starts out empty. */
	public ArrayList<Double> p1Vect = new ArrayList<Double>();
	/** Proportion of samples labelled 1; starts at 0.0. */
	public double pAbusive = 0.0;

	/** Creates an empty result with both vectors empty and {@code pAbusive} at 0.0. */
	public TrainNB0DataSet() {
	}
}

package com.haolidong.Bayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Locale;
import java.util.regex.Pattern;

/**
 * Naive Bayes text classifier with a spam/ham e-mail experiment and a few small demos.
 *
 * <p>Documents are lists of word tokens. They are turned into "0"/"1" presence vectors
 * over a vocabulary, and the classifier works on log-probabilities.
 *
 * @author haolidong
 */
public class Bayes {
	/** Directory of the ham e-mails, which are named 1.txt, 2.txt, ... */
	private static final String HAM_DIR = "I:\\machinelearninginaction\\Ch04\\email\\ham\\";
	/** Directory of the spam e-mails, which are named 1.txt, 2.txt, ... */
	private static final String SPAM_DIR = "I:\\machinelearninginaction\\Ch04\\email\\spam\\";
	/** Number of e-mails read from each of the two directories. */
	private static final int EMAILS_PER_CLASS = 25;
	/** Number of documents held out for testing in {@link #spamTest()}. */
	private static final int TEST_SET_SIZE = 10;
	/** One or more characters that are not letters, digits or underscore. */
	private static final Pattern NON_WORD_RUN = Pattern.compile("\\W+");

	public static void main(String[] args) {
		spamTest();
	}

	/**
	 * Draws {@code num} distinct random integers using {@code (int) (Math.random() * end)}.
	 *
	 * @param end  exclusive upper bound of the range, which starts at 0
	 * @param num  how many distinct values to draw
	 * @return the set of drawn values
	 * @throws IllegalArgumentException if {@code num} is larger than the number of
	 *         distinct values the range can produce
	 * @author haolidong
	 */
	public static HashSet<Integer> randomdif(int end, int num) {
		// The draw below can only produce max(1, |end|) different values; asking for
		// more would make the loop spin forever.
		long attainable = Math.max(1L, Math.abs((long) end));
		if (num > attainable) {
			throw new IllegalArgumentException(
					"cannot draw " + num + " distinct values from a range of size " + attainable);
		}
		HashSet<Integer> picked = new HashSet<Integer>();
		while (picked.size() < num) {
			picked.add((int) (Math.random() * end));
		}
		return picked;
	}

	/**
	 * Spam classification experiment: reads the e-mails, holds out a random test set,
	 * trains on the rest and prints each prediction followed by the error rate.
	 *
	 * @author haolidong
	 */
	public static void spamTest() {
		CreateDataSet dataSet = new CreateDataSet();
		for (int i = 1; i <= EMAILS_PER_CLASS; i++) {
			// Spam documents are labelled "1" and ham documents "0"; they are added alternately.
			dataSet.data.add(textParse(SPAM_DIR + i + ".txt", 2));
			dataSet.labels.add("1");
			dataSet.data.add(textParse(HAM_DIR + i + ".txt", 2));
			dataSet.labels.add("0");
		}
		HashSet<String> vocabList = createVocabList(dataSet);

		// Hold out TEST_SET_SIZE randomly chosen documents; everything else is training data.
		HashSet<Integer> testIndices = randomdif(dataSet.data.size(), TEST_SET_SIZE);
		Matrix testVectors = new Matrix();
		Matrix trainVectors = new Matrix();
		ArrayList<String> testLabels = new ArrayList<String>();
		ArrayList<String> trainLabels = new ArrayList<String>();
		for (Integer index : testIndices) {
			testVectors.data.add(setOfWords2Vec(vocabList, dataSet.data.get(index)));
			testLabels.add(dataSet.labels.get(index));
		}
		for (int i = 0; i < dataSet.data.size(); i++) {
			if (!testIndices.contains(i)) {
				trainVectors.data.add(setOfWords2Vec(vocabList, dataSet.data.get(i)));
				trainLabels.add(dataSet.labels.get(i));
			}
		}

		TrainNB0DataSet model = trainNB0(trainVectors, trainLabels);

		int errorCount = 0;
		for (int i = 0; i < testVectors.data.size(); i++) {
			int predicted = classifyNB(testVectors.data.get(i), model.p0Vect, model.p1Vect, model.pAbusive);
			System.out.println("the predict:" + predicted + " , the real:" + testLabels.get(i));
			if (predicted != Integer.parseInt(testLabels.get(i))) {
				errorCount++;
			}
		}
		System.out.println("the errorRate is:" + 1.0 * errorCount / testVectors.data.size());
	}

	/**
	 * Reads a file and tokenises its content.
	 *
	 * @param fileName  full path of the file to read
	 * @param moreThan  only tokens longer than this many characters are kept
	 * @return the lower-cased tokens of the file
	 */
	public static ArrayList<String> textParse(String fileName, int moreThan) {
		return extractStrlist(readFile(fileName), moreThan);
	}

	/**
	 * Reads a file line by line and joins the lines into one string, appending a single
	 * space after every line.
	 *
	 * <p>Reading is best effort: on an {@link IOException} the stack trace is printed and
	 * whatever was read so far (possibly the empty string) is returned.
	 *
	 * @param fileName  full path of the file to read
	 * @return the file content with a space after each line
	 * @author haolidong
	 */
	public static String readFile(String fileName) {
		StringBuilder content = new StringBuilder();
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(new File(fileName)));
			String line;
			// readLine() returns null at end of file.
			while ((line = reader.readLine()) != null) {
				// The trailing space keeps the last word of a line apart from the next line.
				content.append(line).append(' ');
			}
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException ignored) {
					// The content has already been read; a failed close changes nothing.
				}
			}
		}
		return content.toString();
	}

	/**
	 * Splits a string on runs of non-word characters (anything other than letters,
	 * digits and underscore) and lower-cases the tokens that are kept.
	 *
	 * @param inputString  the text to tokenise
	 * @param moreThan     only non-empty tokens longer than this many characters are kept
	 * @return the kept tokens, lower-cased, in order of appearance
	 * @author haolidong
	 */
	public static ArrayList<String> extractStrlist(String inputString, int moreThan) {
		ArrayList<String> tokens = new ArrayList<String>();
		for (String token : NON_WORD_RUN.split(inputString)) {
			// A leading separator yields one empty token; never report it.
			if (!token.isEmpty() && token.length() > moreThan) {
				tokens.add(token.toLowerCase(Locale.ROOT));
			}
		}
		return tokens;
	}

	/**
	 * Scores a vector against both classes and returns the label with the higher score.
	 *
	 * @param vec2Classify  the vector to classify, as numeric strings
	 * @param p0Vec         log-probability vector of class 0
	 * @param p1Vec         log-probability vector of class 1
	 * @param pClass1       prior probability of class 1
	 * @return 1 if the class-1 score is strictly greater than the class-0 score, otherwise 0
	 * @author haolidong
	 */
	public static int classifyNB(ArrayList<String> vec2Classify, ArrayList<Double> p0Vec, ArrayList<Double> p1Vec,
			double pClass1) {
		double score1 = Math.log(pClass1);
		double score0 = Math.log(1 - pClass1);
		for (int i = 0; i < vec2Classify.size(); i++) {
			double weight = Double.parseDouble(vec2Classify.get(i));
			score1 += weight * p1Vec.get(i);
			score0 += weight * p0Vec.get(i);
		}
		return score1 > score0 ? 1 : 0;
	}

	/**
	 * Trains the naive Bayes model from presence vectors and their labels.
	 *
	 * <p>Rows whose label equals "1" count towards class 1 and all other rows towards
	 * class 0. Word counts start at 1 and the denominators at 2, so every ratio is
	 * positive and its logarithm is finite.
	 *
	 * @param trainMatrix    training vectors, one row of numeric strings per document
	 * @param trainCategory  label of each training row
	 * @return the two log-probability vectors and the proportion of class-1 samples
	 * @throws IllegalArgumentException if {@code trainMatrix} has no rows
	 * @author haolidong
	 */
	public static TrainNB0DataSet trainNB0(Matrix trainMatrix, ArrayList<String> trainCategory) {
		int numTrainDocs = trainMatrix.data.size();
		if (numTrainDocs == 0) {
			throw new IllegalArgumentException("trainMatrix must contain at least one document");
		}
		// The vector length is taken from the first row.
		int numWords = trainMatrix.data.get(0).size();

		double labelSum = 0.0;
		for (String label : trainCategory) {
			labelSum += Double.parseDouble(label);
		}

		double[] p0Num = new double[numWords];
		double[] p1Num = new double[numWords];
		for (int j = 0; j < numWords; j++) {
			p0Num[j] = 1.0;
			p1Num[j] = 1.0;
		}
		double p0Denom = 2.0;
		double p1Denom = 2.0;

		for (int i = 0; i < numTrainDocs; i++) {
			ArrayList<String> row = trainMatrix.data.get(i);
			boolean isClass1 = trainCategory.get(i).equals("1");
			for (int j = 0; j < numWords; j++) {
				double value = Double.parseDouble(row.get(j));
				if (isClass1) {
					p1Num[j] += value;
					p1Denom += value;
				} else {
					p0Num[j] += value;
					p0Denom += value;
				}
			}
		}

		TrainNB0DataSet resultSet = new TrainNB0DataSet();
		resultSet.pAbusive = labelSum / numTrainDocs;
		for (int j = 0; j < numWords; j++) {
			resultSet.p0Vect.add(Math.log(p0Num[j] / p0Denom));
			resultSet.p1Vect.add(Math.log(p1Num[j] / p1Denom));
		}
		return resultSet;
	}

	/**
	 * Builds a presence vector over the vocabulary: "1" for every vocabulary word that
	 * occurs in the input, "0" otherwise, in the iteration order of {@code vocabSet}.
	 *
	 * @param vocabSet  the vocabulary
	 * @param inputSet  the words of one document
	 * @return one "0"/"1" entry per vocabulary word
	 * @author haolidong
	 */
	public static ArrayList<String> setOfWords2Vec(HashSet<String> vocabSet, ArrayList<String> inputSet) {
		// Hash lookup instead of rescanning the document for every vocabulary word.
		HashSet<String> present = new HashSet<String>(inputSet);
		ArrayList<String> returnVec = new ArrayList<String>(vocabSet.size());
		for (String word : vocabSet) {
			returnVec.add(present.contains(word) ? "1" : "0");
		}
		return returnVec;
	}

	/**
	 * Collects every distinct word that appears in any row of the data set.
	 *
	 * @param dataSet  the documents
	 * @return the set of distinct words
	 * @author haolidong
	 */
	public static HashSet<String> createVocabList(Matrix dataSet) {
		HashSet<String> vocabSet = new HashSet<String>();
		for (ArrayList<String> row : dataSet.data) {
			vocabSet.addAll(row);
		}
		return vocabSet;
	}

	/**
	 * Demo: prints the vocabulary of the built-in sample data.
	 *
	 * @author haolidong
	 */
	public static void testVocabList() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		System.out.println(createVocabList(dataSet));
	}

	/**
	 * Demo: prints the presence vector of the first built-in sample document.
	 *
	 * @author haolidong
	 */
	public static void testWord2Vec() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		System.out.println(setOfWords2Vec(vocabSet, dataSet.data.get(0)));
	}

	/**
	 * Demo: trains on the built-in sample data and discards the result.
	 *
	 * @author haolidong
	 */
	public static void testTrain() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		Matrix trainMatrix = new Matrix();
		for (ArrayList<String> document : dataSet.data) {
			trainMatrix.data.add(setOfWords2Vec(vocabSet, document));
		}
		trainNB0(trainMatrix, dataSet.labels);
	}

	/**
	 * Demo: trains on the built-in sample data and prints the classification of two
	 * hand-written entries.
	 *
	 * @author haolidong
	 */
	public static void testingNB() {
		CreateDataSet dataSet = new CreateDataSet();
		dataSet.initTest();
		HashSet<String> vocabSet = createVocabList(dataSet);
		Matrix trainMatrix = new Matrix();
		for (ArrayList<String> document : dataSet.data) {
			trainMatrix.data.add(setOfWords2Vec(vocabSet, document));
		}
		TrainNB0DataSet td = trainNB0(trainMatrix, dataSet.labels);

		ArrayList<String> firstEntry = new ArrayList<String>();
		firstEntry.add("love");
		firstEntry.add("my");
		firstEntry.add("dalmation");
		ArrayList<String> firstVec = setOfWords2Vec(vocabSet, firstEntry);
		System.out.println("classified as:" + classifyNB(firstVec, td.p0Vect, td.p1Vect, td.pAbusive));

		ArrayList<String> secondEntry = new ArrayList<String>();
		secondEntry.add("stupid");
		secondEntry.add("garbage");
		ArrayList<String> secondVec = setOfWords2Vec(vocabSet, secondEntry);
		System.out.println("classified as:" + classifyNB(secondVec, td.p0Vect, td.p1Vect, td.pAbusive));
	}
}




  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值