Bayes分类算法
简介
- 概率论的公式
- 一个小例子
- 算法的思想
- 呈上代码
贝叶斯公式的简介
在这里p(x | y)表示在y事件发生时,x事件发生的概率。
一个小例子
Name | Gender | Height | Class |
---|---|---|---|
张三 | F | 1.68 | Medium |
李四 | M | 1.0 | Short |
王五 | M | 1.9 | Tall |
赵六 | M | 1.2 | Short |
分类算法的目的在于给出了以上面的一些例子作为训练集,按Class将每一个条目分类,训练集里的条目是分好类的,我们根据它训练出一个模型(公式),当有给定的(Name,Gender,Height)时,我们就马上可以输出他的Class结果。
例如现在我们要预测t=(陈七,F,1.67)是属于那个分类的。
为了方便,我们先把身高做一个划分,把身高分为(0,1.6], (1.6,1.7], (1.7,1.8], (1.8,1.9],(1.9,2.0],(2.0,…] 总共6个区间。
分别属于矮,中,高。
那我们现在就要预测t属于哪一类,既要计算:
p(矮 | t) p(中 | t) p(高 | t)的大小,最大的那个就是我们所要选择的。
- 根据贝叶斯公式,为了求p(矮 | t)我们先得求p(t)
- 求p(矮)
- 求p(t | 矮)
- 上述三个都可以根据已经给出的数据求出
- p(中 | t) p(高 | t)的求法类似p(矮 | t)
算法思想
上面的属性有性别(男,女),身高(数值区间),目标属性有分类的(矮,中,高)
例如上面的第一行表示性别为男的条目分类中矮,中,高的条目数分别为1,2,3.
- 下面的算法中Attribute类表示各个属性(性别,身高),它的值域为性别(男,女),目标属性值为:矮,中,高。这个类维护一个计数表,负责计数表的更新。
- DataSet类:维护目标属性,添加新的目标属性值。
- Bayes类:计算各个分类的概率
package Bayes;
import java.util.ArrayList;
public class Attribute {
public ArrayList<String> range;//属性的值域
public ArrayList<ArrayList<Double>> countMatrix;//计数表,统计各属性值在目标属性上的取值的个数,目标属性就是最终的分类属性。
public String attrName;//属性名
public int attrIndex;//属性在DataSet上的序号
public Attribute(String attrName, int attrIndex)
{
this.attrIndex = attrIndex;
this.attrName = attrName;
this.range = new ArrayList<String>();
this.countMatrix = new ArrayList<ArrayList<Double>>();
}
public void AddData(ArrayList<String> dataRow)
{
String columnValue = dataRow.get(attrIndex);
String targetValue = dataRow.get(DataSet.attr.size());
if(range.contains(columnValue))//如果该属性值存在,则在原来的基础上加1
{
int columValueIndex = range.indexOf(columnValue);
int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
ArrayList<Double> matrixRow = countMatrix.get(columValueIndex);
if(targetValueIndex >= matrixRow.size())
{
int targetSize = DataSet.targetValueRange.size();
for(int i = 0; i < (targetSize - matrixRow.size()); i++)//有新的目标属性值,将原来缺失的补齐。
{
matrixRow.add(new Double(0));
}
matrixRow.set(targetValueIndex, matrixRow.get(targetValueIndex)+1);
// System.out.println("add");
}
}
else//若属性值不存在,则得将属性值加进去。
{
this.range.add(columnValue);
int targetValueIndex = DataSet.targetValueRange.indexOf(targetValue);
ArrayList<Double> matrixRow = new ArrayList<Double>();
for(int i = 0; i < DataSet.targetValueRange.size(); i++)//该属性不存在,则为它构建一行新的
{
matrixRow.add(new Double(0));
}
matrixRow.set(targetValueIndex, new Double(1));
this.countMatrix.add(matrixRow);
// System.out.println(matrixRow.get(0));
}
}
}
package Bayes;
import java.util.ArrayList;
import Bayes.Attribute;
public class DataSet {
public static ArrayList<Attribute> attr;//属性集
public String targetAttribute;//目标属性名
public static ArrayList<String> targetValueRange;//目标属性的值域
public static ArrayList<Double> targetValueCount;//目标属性各值出现的次数
/**
* 数据集初始化,输入一个属性集和一个目标属性名
* @param attrSet //属性集
* @param targetAttrbute //目标属性
*/
public DataSet(ArrayList<String> attrSet, String targetAttribute)
{
DataSet.attr = new ArrayList<Attribute>();
for(int i = 0; i < attrSet.size(); i++)
{
DataSet.attr.add(new Attribute(attrSet.get(i),i));
}
this.targetAttribute = targetAttribute;
targetValueCount = new ArrayList<Double>();
this.targetValueRange = new ArrayList<String>();
}
public void addRow(String... datas)
{
ArrayList<String> row = new ArrayList<String>();
for(String str : datas)
{
row.add(str);
}
String targetValue = row.get(DataSet.attr.size());
if(targetValueRange.contains(targetValue))
{
int targetIndex = this.targetValueRange.indexOf(targetValue);
targetValueCount.set(targetIndex, targetValueCount.get(targetIndex) + 1);
}
else
{
targetValueRange.add(targetValue);
targetValueCount.add(1.0);
}
for(int i = 0; i < attr.size(); i++)//更新计数表
{
Attribute att = DataSet.attr.get(i);
att.AddData(row);
}
System.out.println(targetValueRange.size());
}
}
package Bayes;
public class Bayes {
public double[] Test(String... dataRow)
{
//存放各个目标属性的似然值
double[] likelihood = new double[DataSet.targetValueRange.size()];
//计算dataRow相对于各个目标属性的似然值,即:P(dataRow|targetValue)*P(targetValue)
for(int i = 0; i < DataSet.targetValueRange.size(); i++)
{
String targetValue = DataSet.targetValueRange.get(i);
double probOfTarget = getProb(targetValue);
Double probOfData = null;
for(int j = 0; j < DataSet.attr.size(); j++)
{
Attribute attr = DataSet.attr.get(j);
double tempProb = getProb(attr, dataRow[j], targetValue);
if(probOfData == null)
{
probOfData = tempProb;
}
else
{
probOfData *= tempProb;
}
}
likelihood[i] = probOfTarget * probOfData;
}
double sumlikelihood = 0.0;
for(int i = 0; i < likelihood.length; i++)
{
sumlikelihood += likelihood[i];
}
double[] result = new double[DataSet.targetValueRange.size()];
for(int i = 0; i < result.length; i++)
{
result[i] = sumlikelihood == 0.0 ? 0 : (likelihood[i]/sumlikelihood);
}
return result;
}
/**
* 计算P(targetValue)的值
* @param targetValue //指定要计算的分类的值
* @return 概率
*/
private double getProb(String targetValue)
{
double sum = 0.0;//总的目标属性值的次数
double valueCount = 0.0;//标记指定目标属性出现的次数
for(int i = 0; i < DataSet.targetValueRange.size(); i++)
{
String value = DataSet.targetValueRange.get(i);
double count = DataSet.targetValueCount.get(i);
sum += count;
if(targetValue.equals(value))
valueCount = count;
}
return sum < 1 ? 0 : (valueCount / sum);
}
/**
* 获取p(attrValue/targetValue)
* @param attr 属性
* @param attrValue 属性值
* @param targetValue 目标属性
* @return
*/
private double getProb(Attribute attr, String attrValue, String targetValue)
{
double sum = 0.0;
double count = 0.0;
int columnValueIndex = DataSet.targetValueRange.indexOf(targetValue);
int attrIndex = attr.range.indexOf(attrValue);
for(int i = 0; i < attr.range.size(); i++)
{
double tempCount = columnValueIndex >= attr.countMatrix.get(i).size()?0:attr.countMatrix.get(i).get(columnValueIndex);
sum += tempCount;
if(attrIndex == i)
count = tempCount;
}
return sum < 1 ? 0 : (count/sum);
}
}
package Bayes;
import java.util.ArrayList;
public class main {
public static void main(String[] args) {
// TODO Auto-generated method stub
ArrayList<String> attr = new ArrayList<String>();
attr.add("Gender");
attr.add("Height");
DataSet dataSet = new DataSet(attr, "Class");
//添加数据
dataSet.addRow("F", "1", "Short");
dataSet.addRow("F", "1.5", "Medium");
dataSet.addRow("M", "1.8", "Tall");
Bayes bayes = new Bayes();
double[] result = bayes.Test("M", "1.8");
for(int i = 0; i < result.length; i++)
{
String targetValue = DataSet.targetValueRange.get(i);
System.out.println("P("+targetValue+"): " + result[i]) ;
}
}
}