本章的前期工作以及神经网络的搭建:https://blog.csdn.net/ileopard/article/details/102763645
一、可视化界面设计
使用 tkinter来设计可视化界面
1.新建窗体
from tkinter import Label, Menu, DoubleVar, Button, Tk, filedialog
window = Tk() # 创建窗口
window.title("用户页面") # 窗口标题
window.geometry('240x360') # 窗口大小,小写字母x
# 这里可以在窗体内添加其他的控件
# 以上是窗口的主体
window.mainloop() # 结束(不停循环刷新)
2.添加控件(Label, Menu, DoubleVar, Button)
# 最后这个菜单栏没有用到,但是还是把它放在这,以后可能会用到
# ---------------窗口菜单栏
menubar = Menu(window) # 在窗口上添加菜单栏
filemenu = Menu(menubar, tearoff=0) # filemenu放在menu中
submenu = Menu(filemenu) # submenu放在filemenu中
ssubmenu = Menu(submenu) # ssubmenu放在submenu中
menubar.add_cascade(label='File', menu=filemenu) # add_cascade用来创建下拉栏,filemenu命名为File
filemenu.add_command(label='Open', command=Open_image) # add_command用来创建命令栏,不可有子项
filemenu.add_cascade(label='1', menu=submenu) # submenu 命名为1
submenu.add_cascade(label='2', menu=ssubmenu) # ssubmenu 命名为2
window.config(menu=menubar) # 创建完毕
# --------------------------
下面是本次所用到的控件:
# label,如果要在label中设置图片,记得一定要设置参数bitmap
Input_image = Label(width=200,
height=200,
bitmap='warning',
bg='white').grid(row=0, column=0, padx=20)
# 使用grid布局,像表格一样的布局,其中padx表示x距离外部边界的大小,ipadx表示与内部的
# sticky表示位置,w表示西,e表示东
testLabel = Label(window,
text="testAccuracy: ", # 文本
font=('Arial', 10), # 字体和大小
width=10,
height=2, # 字体所占的宽度和高度
).grid(row=1, column=0, sticky='w', pady=5, padx=36)
testAccuracy = Label(window,
textvar=textTest, # 文本
font=('Arial', 10), # 字体和大小
width=10,
height=2, # 字体所占的宽度和高度
bg='white'
).grid(row=1, column=0, sticky='e', pady=5, padx=36)
textTest.set(0.0)
# 使用Button。
startB = Button(
window,
text='开始',
width=8, height=2,
command=application # 执行函数体,而不是得到函数执行的结果
).grid(row=3, column=0, pady=5)
这样大致就做好了界面设计
二、具体功能的实现
大致流程:
下面主要进行三步:
- 从电脑中输入图片(Open_image())
- 预处理输入图片()
- 加载之前训练好的模型进行预测
1.从电脑中输入图片
- 使用filedialog从电脑中选择图片,返回绝对路径。
- 之后通过该路径,打开图片,将其放入可视化界面的Input_image中。
- 最后返回该图片
# 打开电脑中的图片
def Open_image():
global Input_image, File
File = filedialog.askopenfilename(parent=window,
initialdir=ImagePath,
title='Choose an image.')
img = Image.open(File)
img_resized = img.resize((28 * 4, 28 * 4), Image.ANTIALIAS)
filename = ImageTk.PhotoImage(img_resized)
Input_image = Label(image=filename)
Input_image.image = filename
Input_image.grid(row=0, column=0)
return img_resized
2.预处理输入图片
- 将图片resize为28*28大小
- 将图片灰度化
- 由于模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为 255 减去原值以得到互补的反色。
- 把图片reshape成一维数组(784个像素点)
- 将现有的RGB图从0-255之间的数变为0-1之间的浮点数
- 返回预处理好的图片
# 预处理函数
def predicted(img):
# 将图片resize为28*28大小
reIm = img.resize((28, 28), Image.ANTIALIAS)
# 将图片灰度化
im_arr = np.array(reIm.convert('L'))
threshold = 50 # 设定合理的阈值
# 二值化
for i in range(28):
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j]
if im_arr[i][j] < threshold:
im_arr[i][j] = 0
else:
im_arr[i][j] = 255
# 将图片转化为一维数组(784个像素点)
nm_arr = im_arr.reshape([1, 784])
nm_arr = nm_arr.astype(np.float32)
# 将现有的RGB图从0-255之间的数变为0-1之间的浮点数
img_ready = np.multiply(nm_arr, 1.0 / 255.0)
return img_ready
3.加载模型进行预测
- 复现之前定义的计算图(神经网络),记得占位
- 通过checkpoint文件找到最新保存的模型位置
- 进行预测,返回预测值
# 加载模型
def restore_model(testPicArr):
# 复现之前定义的计算图
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, None)
# 得到概率最大的预测值
preValue = tf.argmax(y, 1)
# 计算模型在测试集上的准确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 实现滑动平均模型,参数MOVING_AVERAGE_DECAY用于控制模型更新的速度。训练过程会对每一个变量维护一个影子变量。
# 这个影子变量的初始值就是相应变量的初始值,每次更新时,影子变量就会随之更新
variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
# 通过checkpoint文件找到最新保存的模型位置
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x: testPicArr})
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
textTest.set(accuracy_score)
return preValue
else:
print("No checkpoint file found")
return -1
三、预测结果
这样可视化预测就完成了。
但是有一个问题就是该模型只能用于预测白底黑字的手写数字图片,还需要改进。