最近开始阅读 Weka 部分算法的源码。
NaiveBayesSimple 是最简单的朴素贝叶斯实现：类先验和离散属性的条件概率采用带拉普拉斯平滑的极大似然估计，连续属性则假设服从正态分布（估计其均值和标准差）。核心函数是 buildClassifier。
/**
 * Trains the naive Bayes model from the given instances.
 *
 * Populates the member fields (declared elsewhere in this class):
 * m_Counts  — per class, per attribute: smoothed probabilities of each nominal
 *             value (for nominal attributes) or the raw instance count in
 *             element [0] (for numeric attributes);
 * m_Means   — per class, per numeric attribute: sample mean;
 * m_Devs    — per class, per numeric attribute: sample standard deviation;
 * m_Priors  — Laplace-smoothed class prior probabilities.
 *
 * @param instances the training data
 * @throws Exception if the data cannot be handled, or if a numeric attribute
 *                   has fewer than two values (or zero variance) for a class
 */
public void buildClassifier(Instances instances) throws Exception {
int attIndex = 0;
double sum;
// can classifier handle the data?
getCapabilities().testWithFail(instances);// reject datasets this learner cannot handle
// remove instances with missing class
instances = new Instances(instances);
instances.deleteWithMissingClass();// drop training instances whose class value is missing
m_Instances = new Instances(instances, 0);
// Reserve space
m_Counts = new double[instances.numClasses()]
[instances.numAttributes() - 1][0];// [class][attribute][value]; third dimension is allocated per attribute below
m_Means = new double[instances.numClasses()]
[instances.numAttributes() - 1];
m_Devs = new double[instances.numClasses()]// per-class standard deviations for numeric attributes
[instances.numAttributes() - 1];
m_Priors = new double[instances.numClasses()];// class priors (hold raw counts until normalized at the end)
Enumeration enu = instances.enumerateAttributes();
while (enu.hasMoreElements()) {
Attribute attribute = (Attribute) enu.nextElement();
if (attribute.isNominal()) {// nominal: one counter per distinct attribute value
for (int j = 0; j < instances.numClasses(); j++) {
m_Counts[j][attIndex] = new double[attribute.numValues()];
}
} else {
// numeric: a single slot holding the number of non-missing values seen
for (int j = 0; j < instances.numClasses(); j++) {
m_Counts[j][attIndex] = new double[1];
}
}
attIndex++;
}
// Compute counts and sums
Enumeration enumInsts = instances.enumerateInstances();
while (enumInsts.hasMoreElements()) {
Instance instance = (Instance) enumInsts.nextElement();
if (!instance.classIsMissing()) {
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (!instance.isMissing(attribute)) {
if (attribute.isNominal()) {
// tally this (class, attribute, value) combination
m_Counts[(int)instance.classValue()][attIndex]
[(int)instance.value(attribute)]++;
} else {
// accumulate the sum (turned into a mean in the next pass)
m_Means[(int)instance.classValue()][attIndex] +=
instance.value(attribute);
m_Counts[(int)instance.classValue()][attIndex][0]++;
}
}
attIndex++;
}
m_Priors[(int)instance.classValue()]++;
}
}
// Compute means
Enumeration enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNumeric()) {
for (int j = 0; j < instances.numClasses(); j++) {
// need at least two non-missing values per class to estimate a sample variance
if (m_Counts[j][attIndex][0] < 2) {
throw new Exception("attribute " + attribute.name() +
": less than two values for class " +
instances.classAttribute().value(j));
}
m_Means[j][attIndex] /= m_Counts[j][attIndex][0];// sum -> per-class mean
}
}
attIndex++;
}
// Compute standard deviations
enumInsts = instances.enumerateInstances();
while (enumInsts.hasMoreElements()) {
Instance instance =
(Instance) enumInsts.nextElement();
if (!instance.classIsMissing()) {
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (!instance.isMissing(attribute)) {
if (attribute.isNumeric()) {
// accumulate squared deviation from the class mean
m_Devs[(int)instance.classValue()][attIndex] +=
(m_Means[(int)instance.classValue()][attIndex]-
instance.value(attribute))*
(m_Means[(int)instance.classValue()][attIndex]-
instance.value(attribute));
}
}
attIndex++;
}
}
}
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNumeric()) {
for (int j = 0; j < instances.numClasses(); j++) {
// a zero variance would make the Gaussian density degenerate, so fail fast
if (m_Devs[j][attIndex] <= 0) {
throw new Exception("attribute " + attribute.name() +
": standard deviation is 0 for class " +
instances.classAttribute().value(j));
}
else {
m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;// unbiased sample variance (n - 1)
m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);// variance -> standard deviation
}
}
}
attIndex++;
}
// Normalize counts
enumAtts = instances.enumerateAttributes();
attIndex = 0;
while (enumAtts.hasMoreElements()) {
Attribute attribute = (Attribute) enumAtts.nextElement();
if (attribute.isNominal()) {
for (int j = 0; j < instances.numClasses(); j++) {
sum = Utils.sum(m_Counts[j][attIndex]);
for (int i = 0; i < attribute.numValues(); i++) {
m_Counts[j][attIndex][i] =
(m_Counts[j][attIndex][i] + 1) // Laplace smoothing: add 1 to every count
/ (sum + (double)attribute.numValues());
}
}
}
attIndex++;
}
// Normalize priors
sum = Utils.sum(m_Priors);// class priors, also Laplace-smoothed
for (int j = 0; j < instances.numClasses(); j++)
m_Priors[j] = (m_Priors[j] + 1)
/ (sum + (double)instances.numClasses());
}