@贝叶斯分类器(Python实现+详细完整源码和原理)——补充原理和修正错误
导读
昨天看了一个关于贝叶斯分类器例子的python代码博客,原博客(https://blog.csdn.net/qq_25948717/article/details/81744277)已经很好了,但是个人在通过原理的公式推算和代码的实现后发现其中少了一些东西,导致结果出了一些错误(个人拙见),所以今天自己补充一下。
贝叶斯原理
原理前面的博客已经说了,大家只需要记住下面这个公式就是了
贝叶斯公式:
下面通过实例来解释这个公式的意义以及代码想要实现的功能:
通过贝叶斯原理要实现的就是给出一组特征,例如【较长,甜,不黄】来判断这些特征描述的是什么水果?
根据贝叶斯原理计算得到各自水果的概率然后比较哪种水果的概率较大,则最终的结果就是哪种水果。得到的结果可能是这样【香蕉:0.552552,橘子:0.214565,其他水果:0.112452】,那么最终的结果就是符合该特征【较长,甜,不黄】的水果最有可能是香蕉。
通过贝叶斯公式以香蕉为例就是如下,同理也可计算橘子和其他水果的概率:
而其中的各自的概率如下:
所以最终要算出***P[香蕉|较长、甜、不黄]***的概率就要知道***P[较长]***、***P[甜]***、***P[不黄]***、***P[香蕉]***、***P[较长|香蕉]***、***P[甜|香蕉]***、***P[不黄|香蕉]***的概率。大家先开动脑筋想一下怎么算,一会儿代码里面会一一实现。
python代码实现(代码修改了上个博客中的错误地方,其他地方大同小异)
代码一:随机生成的20组特征,类似于【较长,甜,不黄】或者【不长,不甜,不黄】
import random
def random_attr(pair):
#生成0-1之间的随机数
return pair[random.randint(0,1)]
def gen_attrs():
#特征值的取值集合
sets = [('long','not_long'),('sweet','not_sweet'),('yellow','not_yellow')]
test_datasets = []
for i in range(20):
#使用map函数来生成一组特征值
test_datasets.append(list(map(random_attr,sets)))
return test_datasets
#print(gen_attrs())
代码二:贝叶斯分类器计算概率的结果
####训练数据集---->合适参数
datasets = {'banala':{'long':400,'not_long':100,'sweet':350,'not_sweet':150,'yellow':450,'not_yellow':50},
'orange':{'long':0,'not_long':300,'sweet':150,'not_sweet':150,'yellow':300,'not_yellow':0},
'other_fruit':{'long':100,'not_long':100,'sweet':150,'not_sweet':50,'yellow':50,'not_yellow':150}
}
def count_total(data):
'''计算各种水果的总数
return {‘banala’:500 ...}'''
count = {}
total = 0
for fruit in data:
'''因为水果要么甜要么不甜,可以用 这两种特征来统计总数'''
count[fruit] = data[fruit]['sweet'] + data[fruit]['not_sweet']
total += count[fruit]
return count,total
#categories,simpleTotal = count_total(datasets)
#print(categories,simpleTotal)
###########################################################
def cal_base_rates(data):
'''计算各种水果的先验概率
return {‘banala’:0.5 ...}'''
categories,total = count_total(data)
cal_base_rates = {}
for label in categories:
priori_prob = categories[label]/total
cal_base_rates[label] = priori_prob
return cal_base_rates
#Prio = cal_base_rates(datasets)
#print(Prio)
############################################################
def likelihold_prob(data):
'''计算各个特征值在已知水果下的概率(likelihood probabilities)
{'banala':{'long':0.8}...}'''
count,_ = count_total(data)
likelihold = {}
for fruit in data:
'''创建一个临时的字典,临时存储各个特征值的概率'''
attr_prob = {}
for attr in data[fruit]:
#计算各个特征值在已知水果下的概率
attr_prob[attr] = data[fruit][attr]/count[fruit]
likelihold[fruit] = attr_prob
return likelihold
#LikeHold = likelihold_prob(datasets)
#print(LikeHold)
############################################################
def evidence_prob(data):
'''计算特征的概率对分类结果的影响
return {'long':50%...}'''
#水果的所有特征
attrs = list(data['banala'].keys())
count,total = count_total(data)
evidence_prob = {}
#计算各种特征的概率
for attr in attrs:
attr_total = 0
for fruit in data:
attr_total += data[fruit][attr]
evidence_prob[attr] = attr_total/total
return evidence_prob
#Evidence_prob = evidence_prob(datasets)
#print(Evidence_prob)
##########################################################
#以上是训练数据用到的函数,即将数据转化为代码计算概率
##########################################################
class navie_bayes_classifier:
'''初始化贝叶斯分类器,实例化时会调用__init__函数'''
def __init__(self,data=datasets):
self._data = datasets
# self._data = {
# 'banala': {'long': 400, 'not_long': 100, 'sweet': 350, 'not_sweet': 150, 'yellow': 450, 'not_yellow': 50},
# 'orange': {'long': 0, 'not_long': 300, 'sweet': 150, 'not_sweet': 150, 'yellow': 300, 'not_yellow': 0},
# 'other_fruit': {'long': 100, 'not_long': 100, 'sweet': 150, 'not_sweet': 50, 'yellow': 50,'not_yellow': 150}
# }
self._labels = [key for key in self._data.keys()]
# self._labels = ['banala', 'orange', 'other_fruit'] 各个水果的名称组成的列表
self._priori_prob = cal_base_rates(self._data)
# self._priori_prob = {'banala': 0.5, 'orange': 0.3, 'other_fruit': 0.2} 总的水果中香蕉或者橘子或者其他水果各占的比例
self._likelihold_prob = likelihold_prob(self._data)
# self._likelihold_prob = {
# 'banala': {'long': 0.8, 'not_long': 0.2, 'sweet': 0.7, 'not_sweet': 0.3, 'yellow': 0.9, 'not_yellow': 0.1},
# 'orange': {'long': 0.0, 'not_long': 1.0, 'sweet': 0.5, 'not_sweet': 0.5, 'yellow': 1.0, 'not_yellow': 0.0},
# 'other_fruit': {'long': 0.5, 'not_long': 0.5, 'sweet': 0.75, 'not_sweet': 0.25, 'yellow': 0.25, 'not_yellow': 0.75}
# }
# 各个特征值在已知水果下的概率
self._evidence_prob = evidence_prob(self._data)
# self._evidence_prob ={'long': 0.5, 'not_long': 0.5, 'sweet': 0.65, 'not_sweet': 0.35, 'yellow': 0.8, 'not_yellow': 0.2}各种特征的概率
#下面的函数可以直接调用上面类中定义的变量
def get_label(self,length,sweetness,color):
'''获取某一组特征值的类别'''
self._attrs = [length,sweetness,color]
res = {}
for label in self._labels:
prob = 1 #取某水果占比率
#print("各个水果的占比率:",prob)
for attr in self._attrs:
#单个水果的某个特征概率除以总的某个特征概率 再乘以某水果占比率
prob*=self._priori_prob[label]*self._likelihold_prob[label][attr]/self._evidence_prob[attr]
#print(prob)
res[label] = prob
#print(res)
return res
代码三:主函数
import operator
import bayes_classfier
import generate_attires
def main():
test_datasets = generate_attires.gen_attrs()
classfier = bayes_classfier.navie_bayes_classifier()
for data in test_datasets:
print("特征值:",end='\t')
print(data) # 随机产生的特征列表包括长度、甜度、色度
print("\t预测结果:", end='\t')
res=classfier.get_label(*data)#表示多参传入
print(res)#预测属于哪种水果的概率
print('\t水果类别:',end='\t')
#对后验概率排序,输出概率最大的标签
print(sorted(res.items(),key=operator.itemgetter(1),reverse=True)[0][0])
if __name__ == '__main__':
#表示模块既可以被导入(到 Python shell 或者其他模块中),也可以作为脚本来执行。
#当模块被导入时,模块名称是文件名;而当模块作为脚本独立运行时,名称为 __main__。
#让模块既可以导入又可以执行
main()
实现结果
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.1076923076923077, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: banala
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['not_long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.026923076923076925, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: other_fruit
特征值: ['not_long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.026923076923076925, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: other_fruit
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['long', 'not_sweet', 'not_yellow']
预测结果: {'banala': 0.08571428571428573, 'orange': 0.0, 'other_fruit': 0.021428571428571436}
水果类别: banala
特征值: ['long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.1076923076923077, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: banala
特征值: ['long', 'not_sweet', 'not_yellow']
预测结果: {'banala': 0.08571428571428573, 'orange': 0.0, 'other_fruit': 0.021428571428571436}
水果类别: banala
特征值: ['long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.1076923076923077, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: banala
特征值: ['long', 'sweet', 'yellow']
预测结果: {'banala': 0.24230769230769234, 'orange': 0.0, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.1076923076923077, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: banala
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['not_long', 'not_sweet', 'not_yellow']
预测结果: {'banala': 0.021428571428571432, 'orange': 0.0, 'other_fruit': 0.021428571428571436}
水果类别: other_fruit
特征值: ['not_long', 'sweet', 'yellow']
预测结果: {'banala': 0.060576923076923084, 'orange': 0.051923076923076905, 'other_fruit': 0.0028846153846153848}
水果类别: banala
特征值: ['not_long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.026923076923076925, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: other_fruit
特征值: ['long', 'not_sweet', 'yellow']
预测结果: {'banala': 0.1928571428571429, 'orange': 0.0, 'other_fruit': 0.001785714285714286}
水果类别: banala
特征值: ['not_long', 'sweet', 'not_yellow']
预测结果: {'banala': 0.026923076923076925, 'orange': 0.0, 'other_fruit': 0.034615384615384624}
水果类别: other_fruit
特征值: ['not_long', 'not_sweet', 'not_yellow']
预测结果: {'banala': 0.021428571428571432, 'orange': 0.0, 'other_fruit': 0.021428571428571436}
水果类别: other_fruit