mnist_cnn训练保存模型然后去识别手写数字

mnist是很多人入门机器/深度学习的入门数据集,但是只是用来测试模型和入门学习,而忽略了mnist是一个非常好的数字识别的库。

那么我使用一个非常简单,大概5-6层卷积+池化再加几层全连接的结构来训练一下mnist,然后保存下模型,当我想识别一个字符的时候就可以直接读取这个模型,然后识别这个字符了。

首先是网络模型

net =slim.repeat(net,1,slim.conv2d, 32, [3, 3], scope = 'conv1')
net = slim.max_pool2d(net,[3,3],scope ='pool1',stride = 2)
'''
14*14*32
'''
net = slim.repeat(net, 1, slim.conv2d, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [3, 3], scope='pool2',stride = 2)
'''
7*7*64
'''
net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [3, 3], scope='pool3',stride = 2,padding="VALID")
'''
4*4*128
'''
net = slim.repeat(net, 1, slim.conv2d, 256, [3, 3], scope='conv4')
'''
4*4*256
'''
net = slim.flatten(net, scope='flatten')
net = slim.dropout(net, keep_prob=0.8,
                   is_training=self._is_training)
net = slim.fully_connected(net, 1024, scope='fc1')
net = slim.fully_connected(net, 64, scope='fc2')
net = slim.fully_connected(net, self.num_classes,
                           activation_fn=None, scope='fc3')

然后定义输入的张量的shape是[None,784],标签是[None],然后将这个输入的tensor转化一下shape,转化成可以进行卷积操作的shape

inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
labels = tf.placeholder(tf.int32, shape=[None], name='labels')
cls_model = model_mnist.Model(is_training=True, num_classes=10)
image = tf.reshape(inputs,[-1,28,28,1])

然后识别的时候将图片转化为[1,784]的格式,一次识别一张的话。

import numpy as np
import tensorflow as tf
import cv2
import os
import time

model_ckpt_path = "D:/all_model/mnist_model/model.ckpt"

def main(_):
    with tf.Session() as sess:
        ckpt_path = model_ckpt_path
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)
        inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
        classes = tf.get_default_graph().get_tensor_by_name('classes:0')
        image = cv2.imread("D:/5.jpg", cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (28, 28))
        image_np = np.resize(image,[1,784])
        predicted_label = sess.run(classes, feed_dict={inputs: image_np})
        print(predicted_label)
if __name__ == '__main__':
    tf.app.run()

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值