作者:Irain
QQ:2573396010
微信:18802080892
1 下载数据包
GitHub链接(速度慢):手写Beyes算法之手写体数字识别
码云链接(速度快):手写Beyes(knn)算法之手写体数字识别
2 解压数据包
3 打开代码文件
4 图片转为文本代码
import os
from PIL import Image
def pictureTotxt(data_name):
'''
处理bmp图片转为txt文本
Args:
data_name: 数据类型名称
Return: None
'''
if not os.path.exists("%s_digital_txt" %(data_name)): # 创建txt文本文件夹
os.mkdir("%s_digital_txt" %(data_name))
lists = os.listdir("%s_digital_pictures" %(data_name)) # 获取所有手写体图片名称
for list in lists: # 图片转换文本
im = Image.open("%s_digital_pictures/%s" % (data_name,list)) # 打开手写体图片
fh = open("%s_digital_txt/%s.txt" % (data_name,list.split(".")[0]),"a") # 打开txt文本
width = im.size[0] # 图片宽
height = im.size[1] # 图片高
for k in range(0,width): # 存入txt文本
for j in range(0,height):
cl = im.getpixel((k,j)) # bmp图片的像素只有一个数
if(cl == 0): # 0:黑色
fh.write("0") # 黑色为0
else:
fh.write("1") # 白色为1
fh.write("\n")
fh.close() # 关闭文件
5 加载训练数据
def datatoarray(fname):
'''
加载txt文本数据
Args:
fname: 文本名称
Return: None
'''
arr = []
fh = open(fname) # 打开txt文本
for i in range(0,28):
thisline = fh.readline() # 读取文本内容
for j in range(0,28):
arr.append(int(thisline[j]))
fh.close()
return arr
def seplabel(fname):
'''
获取文本类别
Args:
fname: 文本名称、名称格式例子:0_0.txt
Return: None
'''
filestr = fname.split(".")[0]
label = int(filestr.split("_")[0])
return label
def traindata():
'''
加载训练数据
Return: None
'''
labels = []
trainfile = os.listdir("train_digital_txt/") # 获取手写体图片名
num = len(trainfile)
# 长度784列,每一行存储一个文本
trainarr = np.zeros((num ,784)) # 数组存储训练数据,行:文件总数,列:784=28*28
print("相对路径:train_digital_pictures/") # 转换数据目录
print("所有训练txt文本数量:",num) # num:所有txt文本数量
for i in range(0,num):
thisfname = trainfile[i]
thislabel = seplabel(thisfname)
labels.append(thislabel) # 记录txt文本对应的数字编号
trainarr[i,:] = datatoarray("train_digital_txt/%s" %(thisfname)) # 加载txt文本数据
num = 0
for arrs in trainarr: # 统计有用数据数量
for arr in arrs:
if arr == 1:
num += 1
break
print("统计有用数据数量 :",num)
return trainarr, labels
6 训练数据
class Bayes(object):
def __init__(self):
self.length = -1
self.labelweight = dict() # 权重:类别占总类别总数的比例
self.vectorcount = dict()
def fit(self, dataSet:list, labels:list):
print("-"*40,"开始训练数据","-"*35)
if(len(dataSet) != len(labels)):
raise ValueError("测试组与类别组的长度不一样")
self.length = len(dataSet[0]) # 测试数据特征值的长度
labelsnum = len(labels) # 类别所有的数量
norlabels = set(labels) # 不重复类别的数量
for item in norlabels:
thislabel = item
self.labelweight[thislabel] = labels.count(thislabel)/labelsnum # 权重:每个类别的占比
for vector, label in zip(dataSet, labels):
if(label not in self.vectorcount):
self.vectorcount[label] = []
self.vectorcount[label].append(vector)
for i in range(0,10):
print("类别%s所占比例%s"%(i,self.labelweight[i]))
print("-"*40,"数据训练结束","-"*35)
return self
7 抽某一个文件进行测试
train_data, labels = traindata()
def test_one(train_data,labels,testfile):
'''
抽某一个文件进行测试
Args:
trainarr: 训练集
labels: 训练集类别
testfile: 文本名称
Return: None
'''
labelsall = [0,1,2,3,4,5,6,7,8,9]
testdata=datatoarray("test_digital_txt/" + testfile) # 获取测试样本
bys = Bayes()
bys.fit(train_data,labels)
rst = bys.btest(testdata,labelsall)
print("%s的测试结果:%s"%(testfile,rst))
8 测试数据
测试结果:
准确率:84.1%、正确次数:1851
出错率高的数字:2、3、4、5、8、9
数字6出错次数最少11、出错率0.5%、错判次数最多的数字0
数字9出错次数最多68、出错率3.1%、错判次数最多的数字7
程序运行消耗时间:108.9分
def testdata(self):
'''
Beyes算法测试
Return: None
'''
print("-"*40,"开始测试数据","-"*35)
start = time.time() # 开始时间
#trainarr, labels = traindata()
errors = [] # 所有出错的数字类别
errors_num = [] # 每个数字类别出错的次数
errors_rating = [] # 每个数字类别的错误率
errors_maxdigital = [] # 每个数字类别错判次数最多的数字
maxdigital = [] # 每个数字类别所有错判的数字
testlist = os.listdir("test_digital_txt/" ) # 获取手写体图片名
num = len(testlist)
print("文件存储相对路径:test_digital_txt/") # 转换数据目录
print("所有测试txt文本数量:",num) # num:所有txt文本数量
# 长度784列,每一行存储一个文件
# 以一个数组存储所有训练数据,行:文件总数,列:784
testarr = np.zeros((num,784))
ten = 0
print("")
print("-"*40,"类别",ten,"-"*40)
print("")
for i in range(0,num):
thisfname = testlist[i]
thisdata = datatoarray("test_digital_txt/%s" %(thisfname))
labelsall = [0,1,2,3,4,5,6,7,8,9]
rst = self.btest(thisdata,labelsall)
if str(rst) != thisfname.split("_")[0]: # 预测错误
print("%s样本预测出错:" %(thisfname),rst)
errors.append(int(thisfname.split("_")[0])) # 记录出错数字类别
maxdigital.append(rst) # 判错数字
if (i+1)%220 == 0: # 某个数字类别预测结束
errors_num.append(errors.count(ten)) # 计算某数字类别的出错次数
errors_rating.append(errors_num[ten]/2200) # 计算某数字类别出错率
if maxdigital: # 某数字类别判错次数最大的数字
errors_maxdigital.append(max(maxdigital, key=maxdigital.count))
else:
errors_maxdigital.append("None") # 没有,存入None
print("数字%s出错情况:出错次数%s、出错率%.1f%s、错判次数最多的数字%s" %(ten,errors_num[ten],(errors_rating[ten]*100),"%",errors_maxdigital[ten]))
maxdigital = [] # 清空,记录下一个数字类别
ten += 1
accuracy = (2200 - len(errors)) # 预测正确的所有次数
accuracy_rating = accuracy/2200 # 正确率
elapsed = (time.time() - start) # 开始时间与结束时间之差
print("-"*40,"数据测试结束","-"*35)
print("准确率:%.1f%、正确次数:%s" %((accuracy_rating*100),accuracy))
for k in range(0,10):
print("数字%s出错情况:出错次数%s、出错率%.1f%s、错判次数最多的数字%s" %(k,errors_num[k],(errors_rating[k]*100),"%",errors_maxdigital[k]))
print("程序运行消耗时间:%.1f分" %(elapsed/60))
发布:2020年5月20日