from numpy import *
import operator
from os importlistdir
import numpy as np
class Bayes:
def __init__(self):#初始化
self.length=-1 #length 用于判断是否进行了训练
self.labelcount=dict() #存储标签
self.vectorcount=dict() #存储数据向量
def fit(self,dataSet:list,labels:list):
if(len(dataSet)!=len(labels)):
raise ValueError("您输入的测试数组和类别数组长度不一致!")
self.length=len(dataSet[0])#测试数据特征值的长度
labelsnum=len(labels)#得到所有类别的数量
noRepeatlabels=set(labels)#集合去重 不重复类别的数量
for item in noRepeatlabels:
thislabel=item
self.labelcount[thislabel]=labels.count(thislabel)/labelsnum#得到当前类别在有重复类别中的比例
for vector ,label inzip(dataSet,labels):
if(label not in vectorcount):
self.vectorcount[label]=[]
self.vectorcount[label].append(vector)
print("训练结束!")
return self
def btest(self,TestData,labelSet):
if (self.length==-1):
raise ValueError("还未进行训练!!!先训练!!!")
#计算TestData分别为各个类别的概率
lbDict=dict()
for thislb in labelSet:
p=1
alllabel=self.labelcount[thislb]
allvector=self.vectorcount[thislb]
vnum=len(allvector)
allvector=numpy.array(allvector).T
for index inrange(0,len(TestData)):
vector=list(allvector[index])
p*=vector.count(TestData[index])/vnum
lbDict[thislb]=p*alllabel
thislable=sort(lbDict,key=lambdax:ibDict[x],reverse=True)[0]
return thislabel
#加载数据
def datatoarray(fname):
arr=[]
fh=open(fname)
for i in range(0,32):
thisline=fh.readline()
for j in range(0,32):
arr.apprnd(int(thisline[j]))
return arr
#建立一个函数取文件名前缀
def setlabel(fname):
filestr=fname.split(".")[0]
label=int(filestr.split("_")[0])
return label
#建立训练数据
def traindata():
labels=[]
trainfile=listdir("")
num=len(trainfile)
#长度为1024列,每一行存储一个文件
#用一个数组存储所有训练数据,行:文件总数,列:1024
trainarr=zeros((num,1024))
for i in range(0,num):
thisfname=trainfile[i]
thislabel=seplabel(thisfname)
labels.append(thislabel)
trainarr[i,:]=datatoarray(""+thisfname)
return trainarr,labels
bys=Bayes()
#训练数据
train_data,labels=traindata()
bys.fit(train_data,labels)
#测试
thisdata=datatoarray("path")
labelsall=[]
result=bys.btest(thisdata,labelsall)
print(result)
#识别多个手写体数字(批量测试)
testfileall=listdir("path")
num=len(testfileall) c1
for i in range(0,num):
thisfilename=testfileall[i]
thislabel=setlabel("path"+thisfilename)
thisdataarray=datatoarray("path"+thisfilename)
label=bys.btest(thisdataarray,labelsall)
print("该数字正确的是"+str(thislabel)+",识别出来的数字是: "+str(label))
if(label!=thislabel):
x+=1
print(x)
print("错误率是:"+str(x/num))