// weka: Naive Bayes Classifier

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 */

/*
 *    TAN_W_2_0.java
 *    NB. This is the TAN of Friedman, Geiger, and Goldszmidt.
 *        "Bayesian Network Classifiers",
 *        Machine Learning, Vol. 29, 131-163, 1997.
 * 
 *    Author: Zishuo Wang
 *    Version: 2.0
 *    
 * 
 */


package weka.classifiers.bayes;

import java.util.*;

import weka.core.*;
import weka.core.Capabilities.Capability;
import weka.classifiers.*;


/**
 * Class for building and using a Tree Augmented Naive Bayes (TAN) classifier.
 * This method outperforms naive Bayes, yet at the same time maintains the
 * computational simplicity (no search involved) and robustness that
 * characterize naive Bayes. For more information, see
 * <p>
 * Friedman, N., Geiger, D. & Goldszmidt, M. (1997). Bayesian Network
 * Classifiers. Machine Learning, Vol. 29, pp. 131-163.
 * <p>
 * Model summary:
 * <ul>
 * <li>The structure is a maximum weighted spanning tree over the non-class
 * attributes. The weight of the edge (Ai, Aj) is the conditional mutual
 * information I(Ai; Aj | C) of the empirical distribution of the training
 * instances in which both Ai and Aj are observed.</li>
 * <li>The tree is rooted at the first non-class attribute. Every non-class
 * attribute also depends on the class.</li>
 * <li>Probabilities used for prediction are Laplace-smoothed estimates.</li>
 * <li>Missing attribute values are ignored when counting and when
 * predicting.</li>
 * </ul>
 * Only nominal attributes are supported; numeric attributes must be
 * discretized beforehand.
 * <p>
 * This class does not parse any command-line options.
 *
 * @author Zhihai Wang (zhhwang@bjtu.edu.cn)
 * @author Zishuo Wang
 * @version $Revision: 3.1.0 $
 */
public class TAN_W_2_0 extends Classifier {

	private static final long serialVersionUID = 6763153202302282131L;

	/** The copy of the training instances (missing-class instances removed). */
	protected Instances m_Instances;

	/** The number of instances in the training instances. */
	private double m_NumInstances;

	/** The number of attributes, including the class. */
	private int m_NumAttributes;

	/** The number of class values. */
	protected int m_NumClasses;

	/** The index of the class attribute. */
	private int m_ClassIndex;

	/** The number of training instances observed for each class value. */
	private double[] m_Priors;

	/**
	 * Joint counts indexed as
	 * m_CondiCounts[c][m_StartAttIndex[i] + v][m_StartAttIndex[j] + w], i.e. the
	 * number of training instances of class c with attribute i = v and
	 * attribute j = w (both observed). The diagonal entry m_CondiCounts[c][k][k]
	 * is the number of instances of class c with that single attribute value.
	 * Rows/columns belonging to the class attribute stay zero.
	 */
	private long[][][] m_CondiCounts;

	/** The number of values of all attributes, including the class. */
	private int m_TotalAttValues;

	/** The starting index (in m_CondiCounts) of each attribute's values. */
	private int[] m_StartAttIndex;

	/** The number of values for each attribute. */
	private int[] m_NumAttValues;

	/** The symmetric matrix of conditional mutual information I(Ai; Aj | C). */
	private double[][] m_CondiMutualInfo;

	/**
	 * Parent of each attribute in the tree, indexed by attribute index.
	 * -1 marks the root of the tree and the class attribute.
	 */
	private int[] m_Parents;

	/**
	 * Requested root attribute for the spanning tree. -1 (the default, and the
	 * only value currently used) selects the first non-class attribute.
	 */
	private int m_Root = -1;

	/**
	 * Builds the TAN classifier from the given training data.
	 *
	 * @param data the training data (nominal attributes, nominal class)
	 * @throws Exception if the data cannot be handled by this classifier
	 */
	@Override
	public void buildClassifier(Instances data) throws Exception {
		// can classifier handle the data?
		getCapabilities().testWithFail(data);

		// remove instances with missing class
		m_Instances = new Instances(data);
		m_Instances.deleteWithMissingClass();

		// set up the basic counts describing the dataset
		m_NumInstances = m_Instances.numInstances();
		m_NumAttributes = m_Instances.numAttributes();
		m_NumClasses = m_Instances.numClasses();
		m_ClassIndex = m_Instances.classIndex();

		initValueIndex();
		countInstances();
		computeConditionalMutualInfo();
		maxmumSpanningTree();

		if (m_Debug) {
			printDebugInfo();
		}
	}

	/**
	 * Lays out the values of all attributes (class included) along a single
	 * index axis, filling m_StartAttIndex, m_NumAttValues and m_TotalAttValues.
	 */
	private void initValueIndex() {
		m_StartAttIndex = new int[m_NumAttributes];
		m_NumAttValues = new int[m_NumAttributes];
		// reset so that repeated calls to buildClassifier do not accumulate
		m_TotalAttValues = 0;
		for (int att = 0; att < m_NumAttributes; ++att) {
			m_StartAttIndex[att] = m_TotalAttValues;
			m_NumAttValues[att] = m_Instances.attribute(att).numValues();
			m_TotalAttValues += m_NumAttValues[att];
		}
	}

	/** Fills m_Priors and m_CondiCounts from the training instances. */
	private void countInstances() {
		m_Priors = new double[m_NumClasses];
		m_CondiCounts = new long[m_NumClasses][m_TotalAttValues][m_TotalAttValues];
		for (int i = 0; i < m_Instances.numInstances(); ++i) {
			Instance oneInstance = m_Instances.instance(i);
			addToCount(oneInstance);
			++m_Priors[(int) oneInstance.classValue()];
		}
	}

	/**
	 * Adds one instance to m_CondiCounts. Every ordered pair of observed
	 * non-class attribute values is counted, so the matrix is symmetric.
	 * Missing values are skipped instead of being counted as the first value.
	 *
	 * @param oneInstance the instance to be counted (class must not be missing)
	 */
	private void addToCount(Instance oneInstance) {
		long[][] counts = m_CondiCounts[(int) oneInstance.classValue()];
		for (int j = 0; j < m_NumAttributes; ++j) {
			if (j == m_ClassIndex || oneInstance.isMissing(j)) {
				continue;
			}
			int row = m_StartAttIndex[j] + (int) oneInstance.value(j);
			for (int k = 0; k < m_NumAttributes; ++k) {
				if (k == m_ClassIndex || oneInstance.isMissing(k)) {
					continue;
				}
				++counts[row][m_StartAttIndex[k] + (int) oneInstance.value(k)];
			}
		}
	}

	/** Fills the symmetric matrix m_CondiMutualInfo for all non-class pairs. */
	private void computeConditionalMutualInfo() {
		m_CondiMutualInfo = new double[m_NumAttributes][m_NumAttributes];
		for (int i = 0; i < m_NumAttributes; ++i) {
			if (i == m_ClassIndex) {
				continue;
			}
			for (int j = i + 1; j < m_NumAttributes; ++j) {
				if (j == m_ClassIndex) {
					continue;
				}
				double info = conditionalMutualInfo(i, j);
				m_CondiMutualInfo[i][j] = info;
				m_CondiMutualInfo[j][i] = info;
			}
		}
	}

	/**
	 * Computes the empirical conditional mutual information
	 * I(A1; A2 | C) = sum P(x,y,c) * log( P(x,y,c) P(c) / (P(x,c) P(y,c)) )
	 * over the training instances in which both attributes are observed.
	 * Cells with a zero count contribute nothing (0 * log 0 = 0), and the
	 * result is non-negative up to floating-point rounding.
	 *
	 * @param att1 index of the first (non-class) attribute
	 * @param att2 index of the second (non-class) attribute
	 * @return the conditional mutual information, 0 if no pair was observed
	 */
	private double conditionalMutualInfo(int att1, int att2) {
		int numValues1 = m_NumAttValues[att1];
		int numValues2 = m_NumAttValues[att2];
		int start1 = m_StartAttIndex[att1];
		int start2 = m_StartAttIndex[att2];

		// marginals of the pair table, so that the estimate is consistent
		// when either attribute has missing values
		double[][] count1 = new double[m_NumClasses][numValues1];
		double[][] count2 = new double[m_NumClasses][numValues2];
		double[] countClass = new double[m_NumClasses];
		double total = 0;
		for (int c = 0; c < m_NumClasses; ++c) {
			for (int x = 0; x < numValues1; ++x) {
				for (int y = 0; y < numValues2; ++y) {
					long n = m_CondiCounts[c][start1 + x][start2 + y];
					count1[c][x] += n;
					count2[c][y] += n;
					countClass[c] += n;
				}
			}
			total += countClass[c];
		}
		if (total == 0) {
			return 0;
		}

		double info = 0;
		for (int c = 0; c < m_NumClasses; ++c) {
			for (int x = 0; x < numValues1; ++x) {
				for (int y = 0; y < numValues2; ++y) {
					long n = m_CondiCounts[c][start1 + x][start2 + y];
					if (n == 0) {
						continue;
					}
					// n > 0 implies all three marginals are > 0
					info += (n / total)
						* Math.log(n * countClass[c] / (count1[c][x] * count2[c][y]));
				}
			}
		}
		return info;
	}

	/**
	 * Picks the root of the spanning tree.
	 *
	 * @return m_Root if it names a valid non-class attribute, otherwise the
	 *         first non-class attribute, or -1 if there is none
	 */
	private int chooseRoot() {
		if (m_Root >= 0 && m_Root < m_NumAttributes && m_Root != m_ClassIndex) {
			return m_Root;
		}
		for (int att = 0; att < m_NumAttributes; ++att) {
			if (att != m_ClassIndex) {
				return att;
			}
		}
		return -1;
	}

	/**
	 * Generates a maximum weighted spanning tree over the non-class attributes
	 * using Prim's algorithm on m_CondiMutualInfo, and stores it in
	 * m_Parents (indexed by attribute index; -1 for the root and the class).
	 */
	private void maxmumSpanningTree() {
		m_Parents = new int[m_NumAttributes];
		Arrays.fill(m_Parents, -1);

		int root = chooseRoot();
		if (root < 0) {
			// only the class attribute exists: there is no tree to build
			return;
		}

		// explicit membership flags, so that zero or slightly negative
		// weights cannot be confused with "already in the tree"
		boolean[] inTree = new boolean[m_NumAttributes];
		double[] bestWeight = new double[m_NumAttributes];
		int[] bestNeighbour = new int[m_NumAttributes];

		inTree[root] = true;
		if (m_ClassIndex >= 0) {
			inTree[m_ClassIndex] = true; // the class is never a tree node
		}
		for (int att = 0; att < m_NumAttributes; ++att) {
			bestWeight[att] = m_CondiMutualInfo[root][att];
			bestNeighbour[att] = root;
		}

		while (true) {
			// pick the outside attribute most strongly connected to the tree
			int next = -1;
			for (int att = 0; att < m_NumAttributes; ++att) {
				if (!inTree[att] && (next == -1 || bestWeight[att] > bestWeight[next])) {
					next = att;
				}
			}
			if (next == -1) {
				break; // every non-class attribute is in the tree
			}
			inTree[next] = true;
			m_Parents[next] = bestNeighbour[next];

			// update the best connection of each remaining attribute
			for (int att = 0; att < m_NumAttributes; ++att) {
				if (!inTree[att] && m_CondiMutualInfo[next][att] > bestWeight[att]) {
					bestWeight[att] = m_CondiMutualInfo[next][att];
					bestNeighbour[att] = next;
				}
			}
		}
	}

	/**
	 * Laplace-smoothed ratio (numerator + 1) / (denumerator + i), used to avoid
	 * zero probability estimates.
	 *
	 * @param numerator the count of the event
	 * @param denumerator the count of the conditioning event
	 * @param i the number of possible values of the event
	 * @return the smoothed estimate
	 */
	private double laplaceSmooth(double numerator, double denumerator, int i) {
		return (numerator + 1) / (denumerator + i);
	}

	/**
	 * Computes the class membership probabilities for an instance.
	 * Each class c is scored by P(c) * prod_a P(a | parent(a), c). For the root,
	 * and for attributes whose parent value is missing, the score uses
	 * P(a | c) instead. Attributes with missing values are skipped. Scores are
	 * accumulated in log space and then normalized, so the result sums to 1
	 * even with many attributes.
	 *
	 * @param instance the instance to classify
	 * @return the normalized class probability distribution
	 * @throws Exception if the distribution cannot be computed
	 */
	@Override
	public double[] distributionForInstance(Instance instance) throws Exception {
		double[] logProbs = new double[m_NumClasses];
		for (int c = 0; c < m_NumClasses; ++c) {
			double logProb = Math.log(laplaceSmooth(m_Priors[c], m_NumInstances, m_NumClasses));
			for (int att = 0; att < m_NumAttributes; ++att) {
				if (att == m_ClassIndex || instance.isMissing(att)) {
					continue;
				}
				int value = m_StartAttIndex[att] + (int) instance.value(att);
				int parent = m_Parents[att];
				if (parent == -1 || instance.isMissing(parent)) {
					// P(a | c)
					logProb += Math.log(laplaceSmooth(m_CondiCounts[c][value][value],
						m_Priors[c], m_NumAttValues[att]));
				} else {
					// P(a | parent, c)
					int parentValue = m_StartAttIndex[parent] + (int) instance.value(parent);
					logProb += Math.log(laplaceSmooth(m_CondiCounts[c][parentValue][value],
						m_CondiCounts[c][parentValue][parentValue], m_NumAttValues[att]));
				}
			}
			logProbs[c] = logProb;
		}
		return normalizeLogs(logProbs);
	}

	/**
	 * Converts log scores into a normalized probability distribution,
	 * subtracting the maximum first to avoid underflow.
	 *
	 * @param logProbs finite log scores, one per class
	 * @return probabilities proportional to exp(logProbs), summing to 1
	 */
	private static double[] normalizeLogs(double[] logProbs) {
		double max = Double.NEGATIVE_INFINITY;
		for (double logProb : logProbs) {
			max = Math.max(max, logProb);
		}
		double[] result = new double[logProbs.length];
		double sum = 0;
		for (int i = 0; i < logProbs.length; ++i) {
			result[i] = Math.exp(logProbs[i] - max);
			sum += result[i];
		}
		for (int i = 0; i < result.length; ++i) {
			result[i] /= sum;
		}
		return result;
	}

	/** Prints the internal model state to System.out (used when m_Debug is set). */
	private void printDebugInfo() {
		System.out.println("=================");
		System.out.println("m_NumInstances: " + m_NumInstances);
		System.out.println("m_NumAttributes: " + m_NumAttributes);
		System.out.println("m_NumClasses: " + m_NumClasses);
		System.out.println("m_ClassIndex: " + m_ClassIndex);
		System.out.println("m_TotalAttValues: " + m_TotalAttValues);

		System.out.print("m_StartAttIndex[]: ");
		for (int i = 0; i < m_NumAttributes; ++i) {
			System.out.print(m_StartAttIndex[i] + " ");
		}
		System.out.println();

		System.out.print("m_NumAttValues[]: ");
		for (int i = 0; i < m_NumAttributes; ++i) {
			System.out.print(m_NumAttValues[i] + " ");
		}
		System.out.println();

		System.out.print("m_Priors[]: ");
		for (int i = 0; i < m_Priors.length; ++i) {
			System.out.print(m_Priors[i] + " ");
		}
		System.out.println();

		for (int i = 0; i < m_CondiCounts.length; ++i) {
			System.out.println("m_CondiCounts[][][] in class: " + i);
			for (int j = 0; j < m_CondiCounts[i].length; ++j) {
				for (int k = 0; k < m_CondiCounts[i][j].length; ++k) {
					System.out.print(m_CondiCounts[i][j][k] + " ");
				}
				System.out.println();
			}
		}
		System.out.println();

		System.out.println("m_CondiMutualInfo[][]");
		for (int i = 0; i < m_CondiMutualInfo.length; ++i) {
			for (int j = 0; j < m_CondiMutualInfo[i].length; ++j) {
				System.out.printf("%.4f ", m_CondiMutualInfo[i][j]);
			}
			System.out.println();
		}

		System.out.print("m_Parents[]: ");
		for (int i = 0; i < m_Parents.length; ++i) {
			System.out.print(m_Parents[i] + " ");
		}
		System.out.println();
		System.out.println("=================");
	}

	/**
	 * Returns default capabilities of the classifier.
	 * Numeric attributes are not advertised: values are used directly as
	 * indices into the count tables, so they must be nominal.
	 *
	 * @return      the capabilities of this classifier
	 */
	@Override
	public Capabilities getCapabilities() {
		Capabilities result = super.getCapabilities();
		result.disableAll();

		// attributes
		result.enable(Capability.NOMINAL_ATTRIBUTES);
		result.enable(Capability.MISSING_VALUES);

		// class
		result.enable(Capability.NOMINAL_CLASS);
		result.enable(Capability.MISSING_CLASS_VALUES);

		return result;
	}

	/**
	 * Returns a string describing this classifier
	 * 
	 * @return a description of the classifier suitable for displaying in the
	 *         explorer/experimenter gui
	 */
	public String globalInfo() {
		return "Class for building and using a Tree Augmented Naive Bayes (TAN) "
			+ "classifier. This method outperforms naive Bayes, yet at the same "
			+ "time maintains the computational simplicity (no search involved) "
			+ "and robustness that characterize naive Bayes.\n\n"
			+ "For more information, see\n\n"
			+ "Friedman, N. & Goldszmidt, M. (1996). Building classifiers using "
			+ "Bayesian networks. In: The Proceedings of the National Conference "
			+ "on Artificial Intelligence (pp. 1277-1284). Menlo Park, CA: AAAI Press.\n\n"
			+ "Also see\n\n  Friedman, N., Geiger, D. & Goldszmidt, M. (1997). "
			+ "Bayesian Network Classifiers. Machine Learning, Vol. 29, pp. 131-163.";
	} // End of globalInfo()

	/**
	 * Main method for testing this class.
	 *
	 * @param argv the options
	 */
	public static void main(String[] argv) {
		runClassifier(new TAN_W_2_0(), argv);
	}
}

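/*
 * The text below is left over from the web page this file was copied from:
 * a comment box and a "red envelope" (tip/payment) dialog with wallet-balance
 * notes. It is not part of the program and is kept only as a comment.
 *
 * 评论
 * 添加红包
 *
 * 请填写红包祝福语或标题
 *
 * 红包个数最小为10个
 *
 * 红包金额最低5元
 *
 * 当前余额3.43前往充值 >
 * 需支付:10.00
 * 成就一亿技术人!
 * 领取后你会自动成为博主和红包主的粉丝 规则
 * hope_wisdom
 * 发出的红包
 * 实付
 * 使用余额支付
 * 点击重新获取
 * 扫码支付
 * 钱包余额 0
 *
 * 抵扣说明:
 *
 * 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
 * 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
 *
 * 余额充值
 */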