Shark源码分析(九):朴素贝叶斯算法
关于这个算法,我之前也有写过博客来介绍过。但是Shark在实现时,它只考虑到了输入为连续属性值的情况,而没有考虑到离散属性值的情况。至于连续属性值的情况该如何计算,可以参考下我的博客。
AbstractDistribution类
既然我们需要计算分布情况,那我们需要有一个类来表示分布。AbstractDistribution类就是表示所有分布的基类,该类定义在文件<include/shark/Rng/AbstractDistribution.h>
。
class AbstractDistribution
{
public:
virtual ~AbstractDistribution() {}
//给定一个输入,计算其对应的概率值
virtual double p(double x) const = 0;
//计算输入值概率的log值
//safeLog是一个有界的log函数,避免当输入过小时,其log值也太小
virtual double logP(double x) const { return safeLog(p(x)); }
};
之后肯定会有很多表示不同分布类来继承这个类,这里就不过多的介绍了。
NBClassifier类
该类主要定义了在给定类标签的情况下,每一维属性的分布情况。该类定义在<include/shark/Models/NBClassifier.h>
中。
template <class InputType = RealVector, class OutputType = unsigned int>
class NBClassifier :
public AbstractModel<InputType, OutputType>,
private boost::noncopyable
{
private:
typedef AbstractModel<InputType, OutputType> base_type;
public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
//数据中类别的分布,实际上是每个类别的比例
typedef std::vector<double> ClassPriorsType;
typedef boost::shared_ptr<AbstractDistribution> AbstractDistPtr;
//给定类别是,每一维属性的分布
typedef std::vector<std::vector<AbstractDistPtr> > FeatureDistributionsType;
//第一个参数表示类别的个数,第二个参数表示数据的维度
typedef std::pair<std::size_t, std::size_t> DistSizeType;
//类的构造函数
NBClassifier(std::size_t classSize, std::size_t featureSize)
{
SIZE_CHECK(classSize > 0u);
SIZE_CHECK(featureSize > 0u);
for (std::size_t i = 0; i < classSize; ++i)
{
std::vector<AbstractDistPtr> featureDist;
for (std::size_t j = 0; j < fe