/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*/
/*
* TAN_W_2_0.java
* NB. This is the TAN of Friedman, Geiger, and Goldszmidt.
* "Bayesian Network Classifiers",
* Machine Learning, Vol. 29, 131-163, 1997.
*
* Author: Zishuo Wang
* Version: 2.0
*
*
*/
package weka.classifiers.bayes;
import java.util.*;
import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.classifiers.*;
/**
* Class for building and using a Tree Augmented Naive Bayes(TAN) classifier.
* This method outperforms naive Bayes, yet at the same time maintains the
* computational simplicity (no search involved) and robustness that
* characterize naive Bayes. For more information, see
* <p>
* Friedman,N.,Geiger,D. & Goldszmidt,M. (1997). Bayesian Network Classifiers
* Published in Machine Learning(Vol.29,pp.131-163).
*
* Valid options are:
* <p>
*
* -E num <br>
* The estimation strategies of probabilities. Valid values are:
* 0 For conditional probabilities, using M-estimation plus LaPlace estimation,
* otherwise only using LaPlace estimation.
* 1 For any probability, using Laplace estimation.
* 2 For conditional probabilities, only using M-estimation,
* otherwise only using LaPlace estimation.
* 3 If any probability nearly equals 0, using the constant EPSLON instead.
* (default: 0).
* <p>
*
* -M <br>
* If set, delete all the instances with any missing value (default: false).
* <p>
*
* -R <br>
* Choose the root node for building the maximum spanning tree
* (default: set by random).
* <p>
*
* @author Zhihai Wang (zhhwang@bjtu.edu.cn)
* @version $Revision: 3.1.0 $
*/
public class TAN_W_2_0 extends Classifier{
/** Serialization identifier. */
private static final long serialVersionUID = 6763153202302282131L;
/** The copy of the training instances. */
protected Instances m_Instances;
/** The number of instances in the training instances. */
private double m_NumInstances;
/** The number of trainings with valid class values observed in dataset.
 *  NOTE(review): never written or read in this file — confirm it is still needed. */
private double m_SumInstances = 0;
/** The number of attributes, including the class. */
private int m_NumAttributes;
/** The number of class values. */
protected int m_NumClasses;
/** The index of the class attribute. */
private int m_ClassIndex;
/** The counts for each class value (frequency of each class in the training data). */
private double[] m_Priors;
/**
 * The sums of attribute-class counts.
 * m_CondiPriors[c][k] is the same as m_CondiCounts[c][k][k].
 * NOTE(review): never allocated or used in this file (allocation is commented
 * out in buildClassifier) — the diagonal of m_CondiCounts is used instead.
 */
private long [][] m_CondiPriors;
/** Class-conditional joint value counts:
 *  m_CondiCounts[c][vi][vj] = #instances with class c, value index vi and value index vj.
 *  Sized m_NumClasses * (m_TotalAttValues + m_NumClasses) squared. */
private long[][][] m_CondiCounts;
/** The number of values for all attributes, not including class. */
private int m_TotalAttValues;
/** The starting index (in the m_CondiCounts matrix) of each attribute's value block. */
private int[] m_StartAttIndex;
/** The number of values for each attribute. */
private int[] m_NumAttValues;
/** Joint value counts over all classes, for estimating P(ai, aj). */
private int [][] AandB;
/** The smoothing parameter for M-estimation. */
private final double SMOOTHING = 5.0;
/**
 * The counts (frequency) of each attribute value for the dataset.
 * Here for security, but it can be used for weighting.
 */
//use AandB[ai][ai] to replace m_Frequencies[ai] 14th.August 2008
//private double[] m_Frequencies;
/** The matrix of conditional mutual information, indexed by attribute index. */
private double[][] m_CondiMutualInfo;
/** Threshold constant used by estimation strategy 3 (EPSLON substitution). */
private double EPSILON = 1.0E-4;
/** The array to keep track of which attribute has which parent (tree structure);
 *  the root attribute has parent -1. */
private int[] m_Parents;
/**
 * The smoothing strategy of estimation. Valid values are: 0 For any
 * probability, using M-estimation & Laplace, otherwise using LaPlace
 * estimation. 1 For any probability, using Laplace estimation. 2 For any
 * probability, using M-estimation, and otherwise using LaPlace estimation. 3
 * If any probability nearly equals 0, using the constant EPSLON instead. (If
 * any prior probability nearly equals 0, then throws an Exception.) (default:
 * 0).
 */
private int m_Estimation = 0;
/** If set, delete all the instances with any missing value.
 *  NOTE(review): documented default is false but the field is initialized to
 *  true, and buildClassifier unconditionally sets it to true — confirm. */
private boolean m_DelMissing = true;
/**
 * Choose the root node for building the maximum spanning tree
 * (default: m_Root = -1, i.e., set by random).
 */
private int m_Root = -1;
/** The number of arcs in current traning dataset, only for toString() */
private int m_Arcs = 0;
/** The number of instances with missing values, only for toString(). */
private double m_NumOfMissings = 0;
/**
 * Builds the TAN classifier from the supplied training data:
 * copies the data, drops instances with a missing class, counts joint
 * attribute-value/class frequencies, computes the conditional mutual
 * information between every pair of non-class attributes, and finally
 * builds a maximum spanning tree over those attributes (stored in
 * m_Parents).
 *
 * @param data the training instances
 * @throws Exception if the classifier cannot handle the data
 */
@Override
public void buildClassifier(Instances data) throws Exception {
// can classifier handle the data?
getCapabilities().testWithFail(data);
// remove instances with missing class
m_Instances = new Instances(data);
m_Instances.deleteWithMissingClass();
m_DelMissing = true;
// Cache the basic statistics of the dataset.
m_NumInstances = m_Instances.numInstances();
m_NumAttributes = m_Instances.numAttributes();
m_NumClasses = m_Instances.numClasses();
m_ClassIndex = m_Instances.classIndex();
// Allocate the counting structures.
m_StartAttIndex = new int[m_NumAttributes];
m_NumAttValues = new int [m_NumAttributes];
m_CondiMutualInfo = new double[m_NumAttributes][m_NumAttributes];
m_Priors = new double [m_NumClasses];
m_Parents = new int [m_NumAttributes - 1];
// Compute m_StartAttIndex / m_NumAttValues: each attribute gets a
// contiguous block of value indices starting at m_StartAttIndex[i].
for(int i = 0; i < m_NumAttributes; ++i){
m_StartAttIndex[i] = m_TotalAttValues;
m_NumAttValues[i] = m_Instances.attribute(i).numValues();
m_TotalAttValues += m_NumAttValues[i];
// m_CondiPriors[i] = new long[m_NumAttValues[i]];
}
// Exclude the class attribute's values from the total.
// NOTE(review): this (and m_Parents indexing elsewhere) assumes the class
// attribute is the LAST attribute; if it is not, m_StartAttIndex entries
// after the class are offset by m_NumClasses — confirm with callers.
m_TotalAttValues -= m_NumClasses;
// Count the class-conditional joint occurrences.
m_CondiCounts = new long [m_NumClasses]
[m_TotalAttValues + m_NumClasses]
[m_TotalAttValues + m_NumClasses];
AandB = new int [m_TotalAttValues + m_NumClasses][m_TotalAttValues + m_NumClasses];
for(int i = 0; i < m_NumInstances; ++i){
Instance oneInstance = m_Instances.instance(i);
addToCount(oneInstance);
++m_Priors[(int)oneInstance.classValue()];
}
// Compute the conditional mutual information for every ordered pair of
// non-class attributes.
for(int i = 0; i < m_NumAttributes; ++i){
if(i == m_ClassIndex){
continue;
}
for(int j = 0; j < m_NumAttributes; ++j){
if(j == m_ClassIndex){
continue;
}
if(j == i)continue;
mutualInfo(i,j);
}
}
maxmumSpanningTree();
// Dump all internal structures when debugging is enabled.
if(m_Debug == true){
System.out.println("=================");
System.out.println("m_NumInstances: " + m_NumInstances);
System.out.println("m_NumAttributes: " + m_NumAttributes);
System.out.println("m_NumClasses: " + m_NumClasses);
System.out.println("m_ClassIndex: " + m_ClassIndex);
System.out.println("m_TotalAttValues: " + m_TotalAttValues);
System.out.print("m_StartAttIndex[]: ");
for(int i = 0; i < m_NumAttributes; ++i){
System.out.print(m_StartAttIndex[i] + " ");
}
System.out.println();
System.out.print("m_NumAttValues[]: ");
for(int i = 0; i < m_NumAttributes; ++i){
System.out.print(m_NumAttValues[i] + " ");
}
System.out.println();
System.out.print("m_Priors[]: ");
for(int i = 0; i < m_Priors.length; ++i){
System.out.print(m_Priors[i] + " ");
}
System.out.println();
for(int i = 0; i < m_CondiCounts.length; ++i){
System.out.println("m_CondiCounts[][][] in class: " + i);
for(int j = 0; j < m_CondiCounts[i].length; ++j){
for(int k = 0; k < m_CondiCounts[i][j].length; ++k){
System.out.print(m_CondiCounts[i][j][k] + " ");
}
System.out.println();
}
}
System.out.println();
System.out.println("m_CondiMutualInfo[][]");
for(int i = 0; i < m_CondiMutualInfo.length; ++i){
for(int j = 0; j < m_CondiMutualInfo[i].length; ++j){
System.out.printf("%.4f ",m_CondiMutualInfo[i][j]);
}
System.out.println();
}
System.out.print("m_Parents[]: ");
for(int i = 0; i < m_Parents.length; ++i){
System.out.print(m_Parents[i] + " ");
}
System.out.println();
System.out.println("=================");
}
}
/**
 * Builds a maximum spanning tree over the non-class attributes using
 * Prim's algorithm on the conditional mutual information matrix
 * m_CondiMutualInfo. The resulting parent of each attribute is stored in
 * m_Parents; the root attribute gets parent -1.
 *
 * NOTE(review): indices 0..m_Parents.length-1 are used both as positions
 * in m_Parents and as attribute indices into m_CondiMutualInfo, which is
 * consistent only when the class attribute is the last attribute.
 */
private void maxmumSpanningTree() {
    int rootIndex = m_Root;
    // When no root was chosen by the user, pick one at random.
    // BUG FIX: the original computed (int)Math.random() % m_Parents.length,
    // which is always 0 because Math.random() returns a double in [0, 1)
    // and the cast truncates BEFORE the modulo. Scale first, then truncate.
    if (rootIndex == -1) {
        rootIndex = (int) (Math.random() * m_Parents.length);
    }
    // adjVex[i]  = best neighbour already inside the tree for vertex i;
    // maxCost[i] = weight of that edge (reset to 0 once i joins the tree).
    int[] adjVex = new int[m_NumAttributes - 1];
    double[] maxCost = new double[m_NumAttributes - 1];
    // Initialize every vertex with its edge to the root.
    for (int i = 0; i < m_Parents.length; ++i) {
        if (i != rootIndex) {
            adjVex[i] = rootIndex;
            maxCost[i] = m_CondiMutualInfo[rootIndex][i];
        }
    }
    // Put the root into the tree (U set) and mark it as root in m_Parents.
    maxCost[rootIndex] = 0;
    m_Parents[rootIndex] = -1;
    // Grow the tree one vertex at a time.
    for (int i = 1; i < m_Parents.length; ++i) {
        // The vertex outside the tree with the heaviest edge into the tree.
        int nextVex = Max(maxCost);
        // Record its parent and move it into the tree.
        m_Parents[nextVex] = adjVex[nextVex];
        maxCost[nextVex] = 0;
        // Relax the remaining vertices against the newly added one.
        // NOTE(review): the maxCost[j] > 0 guard also skips vertices whose
        // current best edge weight is exactly 0 — confirm that is intended.
        for (int j = 0; j < m_Parents.length; ++j) {
            if (m_CondiMutualInfo[nextVex][j] > maxCost[j] && maxCost[j] > 0) {
                maxCost[j] = m_CondiMutualInfo[nextVex][j];
                adjVex[j] = nextVex;
            }
        }
    }
}
/**
 * Returns the index of the largest positive entry in maxEdge. Vertices
 * already absorbed into the spanning tree carry the value 0 and are
 * therefore never selected; if every entry is &lt;= 0, index 0 is
 * returned as a fallback.
 *
 * @param maxEdge the current best edge weight for every vertex
 * @return the index of the maximum entry
 */
private int Max(double[] maxEdge) {
    int result = 0;
    double max = 0;
    // Scan the argument's own length rather than the unrelated field
    // m_Parents.length the original relied on (the two lengths happen to
    // coincide, but tying the scan to the argument removes the coupling).
    for (int i = 0; i < maxEdge.length; ++i) {
        if (maxEdge[i] > max) {
            max = maxEdge[i];
            result = i;
        }
    }
    return result;
}
/**
 * Adds one training instance to the joint count tables: AandB (joint
 * value counts over all classes, for P(ai, aj)) and m_CondiCounts
 * (class-conditional joint value counts).
 *
 * @param oneInstance the instance to be counted
 */
private void addToCount(Instance oneInstance) {
// For every ordered pair of attributes — including the class attribute
// and the diagonal pair (j, j), which serves as the per-value marginal —
// bump the joint occurrence counters.
// NOTE(review): values are cast straight to int; a missing value (NaN)
// casts to 0 and is silently counted as the first attribute value —
// confirm missing values are removed upstream before training.
for(int j = 0; j < m_NumAttributes; ++j){
for(int k = 0; k < m_NumAttributes; ++k){
++AandB[(int)oneInstance.value(j) + m_StartAttIndex[j]][(int)oneInstance.value(k) + m_StartAttIndex[k]];
++m_CondiCounts[(int)oneInstance.classValue()]
[(int)oneInstance.value(j) + m_StartAttIndex[j]]
[(int)oneInstance.value(k) + m_StartAttIndex[k]];
}
}
}
/**
 * Accumulates the class-conditional mutual information between two
 * attributes into m_CondiMutualInfo[att1][att2], using Laplace-smoothed
 * probability estimates taken from m_CondiCounts / AandB.
 *
 * NOTE(review): the declared parameter order is (att2, att1) while the
 * body and the result cell are indexed by (att1, att2) — the computation
 * is symmetric in which physical argument is which, but confirm against
 * the call site buildClassifier.mutualInfo(i, j).
 * NOTE(review): Math.abs is applied to the log term; the textbook CMI
 * definition has no absolute value — confirm this is intentional.
 *
 * @param att2 the index of the second attribute (see note above)
 * @param att1 the index of the first attribute (see note above)
 */
private void mutualInfo(int att2, int att1) {
// Iterate over every value pair (j of att1, k of att2).
for(int j = 0; j < m_Instances.attribute(att1).numValues(); ++j){
for(int k = 0; k < m_Instances.attribute(att2).numValues(); ++k){
// Skip value pairs that never co-occur in the data.
if(AandB[j + m_StartAttIndex[att1]][k + m_StartAttIndex[att2]] == 0){
continue;
}
for(int c = 0; c < m_NumClasses; ++c){
double temp = 1;
// P(ai, aj | c), smoothed.
temp = laplaceSmooth(m_CondiCounts[c][j + m_StartAttIndex[att1]][k + m_StartAttIndex[att2]], m_Priors[c], m_NumAttValues[att1] * m_NumAttValues[att2]);
// Divide by P(ai | c) and P(aj | c) (diagonal cells hold the
// per-value class-conditional marginals).
temp /= laplaceSmooth(m_CondiCounts[c][j + m_StartAttIndex[att1]][j + m_StartAttIndex[att1]], m_Priors[c], m_NumAttValues[att1]);
temp /= laplaceSmooth(m_CondiCounts[c][k + m_StartAttIndex[att2]][k + m_StartAttIndex[att2]], m_Priors[c], m_NumAttValues[att2]);
temp = Math.log(temp);
temp = Math.abs(temp);
// Weight by the smoothed joint probability P(ai, aj, c).
temp *= laplaceSmooth(m_CondiCounts[c][j + m_StartAttIndex[att1]][k + m_StartAttIndex[att2]], m_NumInstances, m_NumAttValues[att1] * m_NumAttValues[att2] * m_NumClasses);
m_CondiMutualInfo[att1][att2] += temp;
}
}
}
}
/**
 * Applies Laplace (add-one) smoothing to a probability ratio so that the
 * estimate is never zero: (numerator + 1) / (denominator + numValues).
 *
 * @param numerator   the observed count for the event of interest
 * @param denominator the total count the event is conditioned on
 * @param numValues   the number of distinct values the event can take
 *                    (added to the denominator so the estimates sum to 1)
 * @return the Laplace-smoothed estimate
 */
private double laplaceSmooth(double numerator, double denominator, int numValues) {
    // Renamed parameters only ("denumerator"/"i" in the original);
    // call sites are positional, so nothing else changes.
    return (numerator + 1) / (denominator + numValues);
}
/**
 * Calculates the class membership distribution for the given test
 * instance. Each class probability is the Laplace-smoothed prior times,
 * for every non-class attribute, the smoothed conditional probability of
 * its value given the class and its tree parent (or given the class only,
 * for the root attribute).
 *
 * Fixes over the original: removed a stray debug println executed on
 * every prediction; the class-attribute skip now keeps attIndex in sync
 * with the fetched attribute; the result is normalized to sum to 1, as
 * the Weka distributionForInstance contract expects.
 *
 * @param instance the instance to be classified
 * @return the normalized class membership probabilities
 * @throws Exception if the distribution cannot be computed
 */
public double[] distributionForInstance(Instance instance) throws Exception {
    double[] result = new double[m_NumClasses];
    int attIndex, rootIndex;
    // Compute the score for each class i.
    for (int i = 0; i < m_NumClasses; ++i) {
        // Start from the smoothed class prior.
        result[i] = laplaceSmooth(m_Priors[i], m_NumInstances, m_NumClasses);
        Enumeration attributes = instance.enumerateAttributes();
        // Multiply in each attribute's contribution according to m_Parents[].
        for (int j = 0; j < m_Parents.length; ++j) {
            Attribute attribute = (Attribute) attributes.nextElement();
            attIndex = attribute.index();
            // Skip the class attribute. (enumerateAttributes already omits
            // the class in Weka, so this guard is believed to be dead code;
            // kept defensively, now keeping attIndex in sync — the original
            // advanced the enumeration but left attIndex stale.)
            if (attIndex == m_ClassIndex) {
                attribute = (Attribute) attributes.nextElement();
                attIndex = attribute.index();
            }
            rootIndex = m_Parents[attIndex];
            // Root attribute: condition on the class only.
            if (rootIndex == -1) {
                result[i] *= laplaceSmooth(
                    m_CondiCounts[i][m_StartAttIndex[attIndex] + (int) instance.value(attIndex)]
                                    [m_StartAttIndex[attIndex] + (int) instance.value(attIndex)],
                    m_Priors[i], m_NumClasses * m_NumAttValues[attIndex]);
                continue;
            }
            // Non-root attribute: condition on the class and the parent value.
            result[i] *= laplaceSmooth(
                m_CondiCounts[i][m_StartAttIndex[rootIndex] + (int) instance.value(rootIndex)]
                                [m_StartAttIndex[attIndex] + (int) instance.value(attIndex)],
                m_CondiCounts[i][m_StartAttIndex[rootIndex] + (int) instance.value(rootIndex)]
                                [m_StartAttIndex[rootIndex] + (int) instance.value(rootIndex)],
                m_NumClasses * m_NumAttValues[attIndex] * m_NumAttValues[rootIndex]);
        }
    }
    // Normalize to a proper probability distribution (all factors are
    // Laplace-smoothed and therefore strictly positive, so the sum is > 0).
    Utils.normalize(result);
    return result;
}
/**
 * Returns default capabilities of the classifier.
 *
 * All counting structures in this class are sized from
 * Attribute.numValues() and attribute values are cast directly to array
 * indices, so only nominal attributes are supported: a numeric attribute
 * reports zero values and would cause an out-of-bounds access during
 * counting. NUMERIC_ATTRIBUTES is therefore no longer advertised
 * (the original enabled it, which let unusable data reach training).
 *
 * @return the capabilities of this classifier
 */
public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();
    // attributes: nominal only (see note above). Missing attribute values
    // do not crash the counting code, although they are counted as the
    // attribute's first value — see addToCount.
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);
    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);
    return result;
}
/**
 * Returns a description of this classifier for display in the
 * explorer/experimenter GUI, including its literature references.
 *
 * @return a human-readable description of the classifier
 */
public String globalInfo() {
    StringBuilder info = new StringBuilder();
    info.append("Class for building and using a Tree Augmented Naive Bayes(TAN) classifier.");
    info.append("This method outperforms naive Bayes, yet at the same time maintains ");
    info.append("the computational simplicity(no search involved) and robustness that ");
    info.append("characterize naive Bayes.\n\n");
    info.append("For more information, see\n\n");
    info.append("Friedman, N. & Goldszmidt, M. (1996). Building classifiers using ");
    info.append("Bayesian networks. In: The Proceedings of the National Conference on ");
    info.append("Artificial Intelligence(pp.1277-1284).Menlo Park, CA:AAAI Press.");
    info.append("also see \n\n ");
    info.append("Friedman, N., Geiger,D. & Goldszmidt, M. (1997). Bayesian Network ");
    info.append("Classifiers. Machine Learning, Vol.29,pp.131-163");
    return info.toString();
} // End of globalInfo()
/**
 * Main method for testing this class from the command line; delegates to
 * Weka's standard classifier runner (options are parsed by the framework).
 *
 * @param argv the command-line options
 */
public static void main(String [] argv) {
runClassifier(new TAN_W_2_0(), argv);
}
}
// weka: Naive Bayes Classifier
// (stray web-page footer from the source this file was copied from;
//  translated: "Latest recommended article published 2021-02-25 18:33:00" —
//  commented out so the file compiles)