import io
import os.path
import tkinter as tk
from os import listdir
from tkinter import messagebox
import os
import PIL
import numpy as np
from PIL import Image
import cv2
import joblib
class digitCanvas: #
# 1.入口文件,设置各种参数
def __init__(self,title="数字画板",rootW=300):
self.title=title
self.root=tk.Tk()
self.rootW=rootW
self.rootH=400
self.rootX=300
self.canvasW=self.rootW
self.canvasH=int(self.rootH/2)
self.rootY=300
self.textWidth=20
self.isLoadModel=False
self.knn=None
self.__ui()
# 2.设置软件的界面
def __ui(self):
self.root.title(self.title)
self.root.geometry(f"{self.rootW}x{self.rootH}+{self.rootX}+{self.rootY}")
# 设置画板
canvas=tk.Canvas(self.root,width=self.canvasW,height=self.canvasH,bg="#fff")
canvas.grid(row=0,column=0,columnspan=5)
self.lastx=None
self.lasty=None
canvas.bind("<B1-Motion>",self.__draw)
canvas.bind("<ButtonRelease-1>",self.__release)
self.canvas=canvas
# 设置输入框的文字
frame=tk.Frame(self.root)
frame.grid(row=1,column=0,columnspan=5,pady=10)
# 设置提示文字
text=tk.Label(frame,text="请输入(0-9)的数字:")
text.pack(side="left")
# 设置输入框
input = tk.Entry(frame)
input.pack(side="left")
self.input=input
# 设置清空按钮按钮
clearBtn=tk.Button(self.root,text="清空画布",command=self.__clearFun)
clearBtn.grid(row=2,column=0,pady=10)
self.clearBtn = clearBtn
# 设置保存按钮按钮
saveBtn = tk.Button(self.root, text="保存",command=self.__saveFun)
saveBtn.grid(row=2, column=1, pady=10)
self.saveBtn = saveBtn
# 设置训练按钮按钮
trainBtn = tk.Button(self.root, text="开始训练",command=self.__train)
trainBtn.grid(row=2, column=2, pady=10)
self.trainBtn = trainBtn
# 设置预测按钮按钮
preBtn = tk.Button(self.root, text="预测",command=self.__preFun)
preBtn.grid(row=2, column=3, pady=10)
self.preBtn = preBtn
# 设置载入按钮按钮
loadBtn = tk.Button(self.root, text="载入模型",command=self.__loadModel)
loadBtn.grid(row=2, column=4, pady=10)
self.loadBtn = loadBtn
# 提示消息
notice = tk.Label(self.root, text="请开始操作",fg="red")
notice.place(x=200,y=350)
self.notice=notice
self.root.mainloop()
# 3. 写字
# 3.1 开始写字
def __draw(self, e):
if self.lastx and self.lasty:
self.canvas.create_line(self.lastx, self.lasty, e.x, e.y, fill="black", width=self.textWidth, capstyle="round")
self.lastx = e.x
self.lasty = e.y
# 3.2.写字完毕
def __release(self, e):
self.lastx = None
self.lasty = None
# 3.3 清空画布
def __clearFun(self):
self.canvas.delete("all")
self.notice.config(text="画布已清空")
# 4. 保存内容
def __saveFun(self):
label = self.input.get()
if not label.isdigit() or int(label) > 9 or int(label) < 0:
messagebox.showerror(message="请输入0-9之间的数字")
return
dirname = f"images/{label}"
if not os.path.exists(dirname):
os.makedirs(dirname)
filename = f"{label}_{len(os.listdir(dirname)) + 1}.png"
fullname = dirname + "/" + filename
# 调用该方法,将画布中的图形保存成指定的文件
self.__canvas_save_png(fullname)
self.notice.config(text="图像已保存")
# 4.1. 将画布保存成png
def __canvas_save_png(self, fullname):
img = self.canvas.postscript(colormode="color")
img = Image.open(io.BytesIO(img.encode("utf-8")))
img = img.convert("RGB")
img = img.convert("L")
tempname = "tempimg.png"
img.save(tempname)
imgdatas = self.__preprocess_img(tempname)
imgdatas = (imgdatas.reshape((28, 28)) * 255).astype(np.uint8)
Image.fromarray(imgdatas).save(fullname)
if os.path.exists(tempname):
os.remove(tempname)
# 4.2.预处理 读取图形内容 缩放转换
def __preprocess_img(self, filepath):
# 将原始图片的像素读入到内存里面
img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
# 获取图像边缘
img = 255 - img
find, _ = cv2.findContours(img, 1, 2)
cont = max(find, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(cont)
short = img[y:y + h, x:x + w]
# 缩放
x, y = short.shape
scale = 20 / (max(x, y))
resize = cv2.resize(short, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
# 将缩放的图片放到28x28的中央
y, x = resize.shape
newx = (28 - x) // 2
newy = (28 - y) // 2
newimg = np.zeros((28, 28), dtype=np.uint8)
newimg[newy:newy + y, newx:newx + x] = resize
newimg = newimg / 255
return newimg.flatten()
# 5. 加载数据,训练数据,评估数据,优化参数,保存模型
def __train(self):
self.notice.config(text="开始训练...")
# 1. 加载数据数据
datas, target = self.__loadDatas()
# 2. 开始训练 # knn
# 2.1 导入相应的库,算法、评估模型(写你的代码)
# 2.2 拆分数据集(写你的代码)
# 2.3 进行评估,从最好的,最好的距离算法,是否加入权重
best_score = 0
best_k = 2
best_p = 1
best_w = ""
#(写你的代码)
print(f"最高分是({best_score}):k={best_k},w={best_w},p={best_p}")
# 2.4 通过最好的参数,创建模型(写你的代码)
# 2.5 将拟合的模型,进行保存(写你的代码)
self.notice.config(text="训练完成,模型已生成,要预测请载入模型")
# 5.1加载数据
def __loadDatas(self):
rootdir = "images"
datas = []
target = []
for dirname in os.listdir(rootdir):
for filename in os.listdir(rootdir + "/" + dirname):
fullname = rootdir + "/" + dirname + "/" + filename
data = self.__preprocess_img(fullname)
datas.append(data)
target.append(int(dirname))
return np.array(datas), np.array(target)
# 6.载入模型
def __loadModel(self):
if self.isLoadModel:
self.notice.config(text="模型已经载入了")
return
self.isLoadModel = True
self.knn = joblib.load("knn_model.pkl")
# 7. 预测函数
def __preFun(self):
#1.先载入训练好的模型
if not self.isLoadModel:
messagebox.showerror(message="请先载入模型")
return
# 2.将预测的数字转换成标准的格式
self.__canvas_save_png("pre.png")
# 3. 将转换完的图片,转换成矩阵,用于运算
predata=self.__preprocess_img("pre.png").reshape(1,-1)
# 4. 将图片数据用KNN算法进行预测
result=self.knn.predict(predata)
# 5. 将预测的结果输出
self.notice.config(text=f"预测的结果是{result}")
obj=digitCanvas(rootW=500)补充代码缺失部分,使其完成手写数字识别的功能
最新发布