python实现一个朴素贝叶斯分类方法

1.公式

上式中左边D是需要预测的测试数据属性,h是需要预测的类;右边式子分子是属性的条件概率和类别的先验概率,可以从统计训练数据中得到,分母对于所有实例都一样,可以不考虑,所有只需 ,返回最大概率的那个类别。但是如果测试数据中没有那个属性,整个预测概率会是0;此外,此式针对离散型属性进行训练,针对连续的数值型属性可以考虑分段,也可以假设其满足某种分布,比如正态分布,利用概率密度函数求概率。

2.部分改进

(1).针对测试数据中没有那个属性,可以平滑一下,比如下(针对非数值型属性):

上式中n是某个类别下的实例数,nc是此类别下的属性个数,m是此属性的取值个数,p是此属性取值出现的概率。比如一个属性:性别,取值男或女,则 m=2,p=1/2。

(2).针对连续的数值型属性,可以分段比如年龄0-10为A,10-30为B等;还可以假设它服从高斯分布(正态分布),利分布函数计算概率:

其中uij是某列数值型属性的均值,Qij是某列数值型属性样本标准差,Xi是数值属性。训练的时候只需要统计均值,样本标准差就行了,预测的时候利用。

3.python实现

  1 #!/usr/bin/python
  2 # -\*- coding: utf-8 -\*-
  3 
  4 import codecs  5 import math  6 
  7 class BayesClassifier:  8 
  9     def \_\_init\_\_(self,dataFormat):
 10         self.prior = {}#类别的先验概率
 11         self.conditional = {}#属性的条件概率
 12         # 输入的数据格式,attr表示非数值型属性,num表示数值型属性,class表示类别
 13         self.format=dataFormat.strip().split('\\t')
 14 
 15     #读取数据
 16     def readData(self,dataFile): 17         total = 0#所有实例数
 18         self.classes = {}#统计类别
 19         self.counts = {}#用来统计
 20         totals={}#统计数值型每列的和
 21         numericValues={}#数值型每列值
 22 
 23         with codecs.open(dataFile,'r','utf-8') as f:
 24             for line in f: 25                 fields=line.strip().split('\\t')
 26                 fieldSize=len(fields)
 27                 vector=\[\]
 28                 nums=\[\]
 29                 for i in range(fieldSize): 30                     if self.format\[i\]=='num':
 31                         nums.append(float(fields\[i\]))
 32                     elif self.format\[i\]=='attr':
 33                         vector.append(fields\[i\])
 34                     elif self.format\[i\]=='class':
 35                         category=fields\[i\]
 36                 total+=1
 37                 self.classes.setdefault(category,0)
 38                 self.counts.setdefault(category,{})
 39                 totals.setdefault(category,{})
 40                 numericValues.setdefault(category,{})
 41                 self.classes\[category\]+=1
 42                 #统计一条非数值型实例的属性
 43                 col=0
 44                 for columnValue in vector: 45                     col+=1
 46                     self.counts\[category\].setdefault(col,{})
 47                     self.counts\[category\]\[col\].setdefault(columnValue,0)
 48                     self.counts\[category\]\[col\]\[columnValue\]+=1
 49                 col=0
 50                 for columnValue in nums: 51                     col+=1
 52                     totals\[category\].setdefault(col,0)
 53                     totals\[category\]\[col\]+=columnValue
 54                     numericValues\[category\].setdefault(col,\[\])
 55                     numericValues\[category\]\[col\].append(columnValue)
 56 
 57         #以上统计完成,计算类别先验概率和属性条件概率
 58         #计算类的先验概率=此类的实例数/总的实例数
 59         for category,count in self.classes.items(): 60             self.prior\[category\]=count/total
 61         #计算属性的条件概率=此类中属性数/此类实例数
 62         for category,columns in self.counts.items(): 63             self.conditional.setdefault(category,{})
 64             for col,valueCounts in columns.items(): 65                 self.conditional\[category\].setdefault(col,{})
 66                 colSize=len(valueCounts)#这一列属性的取值个数(如性别取值为男和女,则colSize=2)
 67                 for attr,count in valueCounts.items(): 68                     #平滑一下
 69                     self.conditional\[category\]\[col\]\[attr\]=(count+colSize\*1/colSize)/(self.classes\[category\]+colSize)
 70         #在数值型列中计算均值和样本标准差
 71         #每列的均值
 72         self.means={}
 73         self.totals=totals
 74         for category,columns in totals.items(): 75             self.means.setdefault(category,{})
 76             for col,colSum  in columns.items(): 77                 self.means\[category\]\[col\]=colSum/self.classes\[category\]
 78         #每列的标准差
 79         self.std={}
 80         for category,columns in numericValues.items(): 81             self.std.setdefault(category,{})
 82             for col,values in columns.items(): 83                 ssd=0
 84                 mean=self.means\[category\]\[col\]
 85                 for value in values: 86                     ssd+=(value-mean)\*\*2
 87                 self.std\[category\]\[col\]=math.sqrt(ssd/(self.classes\[category\]-1))
 88 
 89 
 90     #分类,返回分类结果
 91     def classify(self,itemVector): 92         results=\[\]
 93         for category,prior in self.prior.items(): 94             prob=prior
 95             col=1
 96             for attrValue in itemVector: 97                 if self.format\[col\]=='attr':
 98                     # 如果预测数据没有这个属性,则平滑一下,不是返回0(返回0会导致整个预测结果为0)
 99                     if not attrValue in self.conditional\[category\]\[col\]:
100                         colSize=len(self.counts\[category\]\[col\])
101                         prob=prob\*(0+colSize\*1/colSize)/(self.classes\[category\]+colSize)
102                     else:
103                         prob=prob\*self.conditional\[category\]\[col\]\[attrValue\]
104                 #针对数值型,我们先得到该列均值与样本标准差,利用正态分布得到概率(假设该列数值满足正态分布)
105                 elif self.format\[col\]=='num':
106                     mean=self.means\[category\]\[col\]
107                     std=self.std\[category\]\[col\]
108                     prob=prob\*self.gaussian(mean,std,attrValue)
109                 col+=1
110 results.append((prob,category))
111         return max(results)\[1\]
112 
113     #高斯分布
114     def gaussian(self,mean,std,x):
115         sqrt2pi = math.sqrt(2 \* math.pi)
116         ePart=math.pow(math.e,-(x-mean)\*\*2/(2\*std\*\*2))
117         prob=(1.0/sqrt2pi\*std)\*ePart
118         return prob
119 
120     # 十折验证读取数据,prefix为文件名前缀,i作为测试集编号
121     def tenFoldReadData(self,prefix,testNumber):
122         total = 0  # 所有实例数
123         self.classes = {}  # 统计类别
124         self.counts = {}  # 用来统计
125         totals = {}  # 统计数值型每列的和
126         numericValues = {}  # 数值型每列值
127 
128         for i in range(1,11):
129             if i!=testNumber:
130                 filename='%s-%02s' % (prefix,i)
131                 with codecs.open(filename, 'r', 'utf-8') as f:
132                     for line in f:
133                         fields = line.strip().split('\\t')
134                         fieldSize = len(fields)
135                         vector = \[\]
136                         nums = \[\]
137                         for i in range(fieldSize):
138                             if self.format\[i\] == 'num':
139 nums.append(float(fields\[i\]))
140                             elif self.format\[i\] == 'attr':
141 vector.append(fields\[i\])
142                             elif self.format\[i\] == 'class':
143                                 category = fields\[i\]
144                         total += 1
145 self.classes.setdefault(category, 0)
146 self.counts.setdefault(category, {})
147 totals.setdefault(category, {})
148 numericValues.setdefault(category, {})
149                         self.classes\[category\] += 1
150                         # 统计一条非数值型实例的属性
151                         col = 0
152                         for columnValue in vector:
153                             col += 1
154 self.counts\[category\].setdefault(col, {})
155 self.counts\[category\]\[col\].setdefault(columnValue, 0)
156                             self.counts\[category\]\[col\]\[columnValue\] += 1
157                         col = 0
158                         for columnValue in nums:
159                             col += 1
160 totals\[category\].setdefault(col, 0)
161                             totals\[category\]\[col\] += columnValue
162 numericValues\[category\].setdefault(col, \[\])
163 numericValues\[category\]\[col\].append(columnValue)
164 
165         # 以上统计完成,计算类别先验概率和属性条件概率
166         # 计算类的先验概率=此类的实例数/总的实例数
167         for category, count in self.classes.items():
168             self.prior\[category\] = count / total
169         # 计算属性的条件概率=此类中属性数/此类实例数
170         for category, columns in self.counts.items():
171 self.conditional.setdefault(category, {})
172             for col, valueCounts in columns.items():
173 self.conditional\[category\].setdefault(col, {})
174                 colSize = len(valueCounts)  # 这一列属性的取值个数(如性别取值为男和女,则colSize=2)
175                 for attr, count in valueCounts.items():
176                     # 平滑一下
177                     self.conditional\[category\]\[col\]\[attr\] = (count + colSize \* 1 / colSize) / (
178                     self.classes\[category\] + colSize)
179         # 在数值型列中计算均值和样本标准差
180         # 每列的均值
181         self.means = {}
182         self.totals = totals
183         for category, columns in totals.items():
184 self.means.setdefault(category, {})
185             for col, colSum in columns.items():
186                 self.means\[category\]\[col\] = colSum / self.classes\[category\]
187         # 每列的标准差
188         self.std = {}
189         for category, columns in numericValues.items():
190 self.std.setdefault(category, {})
191             for col, values in columns.items():
192                 ssd = 0
193                 mean = self.means\[category\]\[col\]
194                 for value in values:
195                     ssd += (value - mean) \*\* 2
196                 self.std\[category\]\[col\] = math.sqrt(ssd / (self.classes\[category\] - 1))
197 
198     #利用十折交叉验证,测试一个桶中的数据,prefix为统计文件名前缀,testNumber为要测试的一个桶中的数据
199     def testOneBucket(self,prefix,testNumber):
200         filename='%s-%02i' % (prefix,testNumber)
201         totals={}
202         with codecs.open(filename,'r','utf-8') as f:
203             for line in f:
204                 data=line.strip().split('\\t')
205                 itemVector=\[\]
206                 classInColumn=-1
207                 for i in range(len(self.format)):
208                     if self.format\[i\]=='num':
209 itemVector.append(float(data\[i\]))
210                     elif self.format\[i\]=='attr':
211 itemVector.append(data\[i\])
212                     elif self.format\[i\]=='class':
213                         classInColumn=i
214                 realClass=data\[classInColumn\]#真实的类
215                 classifiedClass=self.classify(itemVector)#预测的类
216 totals.setdefault(realClass,{})
217 totals\[realClass\].setdefault(classifiedClass,0)
218                 totals\[realClass\]\[classifiedClass\]+=1
219         return totals
220 
221 #十折交叉验证,prefix为十个文件名字的前缀,dataForamt为数据格式
222 def tenfold(prefix,dataFormat):
223     results={}
224     for i in range(1,11):
225         classify=BayesClassifier(dataFormat)
226 classify.tenFoldReadData(prefix,i)
227         totals=classify.testOneBucket(prefix,i)
228         for key,value in totals.items():
229 results.setdefault(key,{})
230             for ckey,cvalue in value.items():
231 results\[key\].setdefault(ckey,0)
232                 results\[key\]\[ckey\]+=cvalue
233     #结果展示
234     classes=list(results.keys())
235 classes.sort()
236     print(      '\\n                 classes as: ')
237     header='                '
238     subheader='               +'
239     for cls in classes:
240         header+='%  10s '% cls
241         subheader+='\--------+'
242     print(header)
243     print(subheader)
244     total=0.0
245     correct=0.0
246     for cls in classes:
247         row=' %10s   |' % cls
248         for c2 in classes:
249             if c2 in results\[cls\]:
250                 count=results\[cls\]\[c2\]
251             else:
252                 count=0
253             row+=' %5i |' % count
254             total+=count
255             if c2==cls:
256                 correct+=count
257         print(row)
258     print(subheader)
259     print('\\n%5.3f 正确率' % ((correct\*100/total)))
260     print('总共 %i 实例'% total)
261 
262 if \_\_name\_\_\=='\_\_main\_\_':
263     #classify=BayesClassifier('num,num,num,num,num,num,num,num,class')
264     #classify.readData('dataFile')
265     #print(classify.classify(\[2,120,54,0,0,26.8,0.455,27\]))
266     tenfold('dataFilePrefix','num,num,num,num,num,num,num,num,class')#十折交叉验证
  • 7
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值