假设我们已经安装好了tensorflow。
一般在安装好tensorflow后,都会跑它的demo,而最常见的demo就是手写数字识别的demo,也就是mnist数据集。
然而我们仅仅是跑了它的demo而已,可能很多人会有和我一样的想法,如果拿来一张数字图片,如何应用我们训练的网络模型来识别出来,下面我们就以mnist的demo来实现它。
1.训练模型
首先我们要训练好模型,并且把模型model.ckpt保存到指定文件夹
saver = tf.train.Saver()
saver.save(sess, "model_data/model.ckpt")
将以上两行代码加入到训练的代码中,训练完成后保存模型即可,如果这部分有问题,你可以百度查阅资料,tensorflow怎么保存训练模型,在这里我们就不罗嗦了。
2.测试模型
我们训练好模型后,将它保存在了model_data文件夹中,你会发现文件夹中出现了4个文件
然后,我们就可以对这个模型进行测试了,将待检测图片放在images文件夹下,执行
# -*- coding:utf-8 -*-
import cv2
import tensorflow as tf
import numpy as np
from sys import path
path.append('../..')
from common import extract_mnist
#初始化单个卷积核上的参数
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0