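Input a handwritten digit, output the recognition result
Resuming training from a checkpoint
Feeding real images and outputting predictions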
To resume training from a checkpoint, only three lines need to be added to mnist_backward.py:
```python
# Excerpt from mnist_backward.py: saver, train_op, loss, global_step, x, y_, mnist
# and the constants (MODEL_SAVE_PATH, MODEL_NAME, STEPS, BATCH_SIZE) are defined earlier in the file.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Resume from a checkpoint ------------------------------------------
    # restore() runs after the initializer, so the saved values overwrite the fresh ones
    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
    # -------------------------------------------------------------------

    for i in range(STEPS):
        xs, ys = mnist.train.next_batch(BATCH_SIZE)
        _, loss_v, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
        if i % 1000 == 0:
            print('After %d training steps, loss on training batch is %g.' % (step, loss_v))
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
```
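The resume pattern in the excerpt above (initialize, overwrite with the latest checkpoint if there is one, keep counting from the saved step, save periodically) can be tried without TensorFlow. The sketch below is only a toy stand-in: it assumes a JSON file (the hypothetical `latest.json`) plays the role of the checkpoint and a single number `w` plays the role of the weights, which is not how `tf.train.Saver` actually stores data.

```python
import json
import os
import tempfile


def train(ckpt_dir, steps, save_every=2):
    """Toy resume-training loop (not TensorFlow): the 'model' is one float w."""
    state = {"global_step": 0, "w": 0.0}               # like global_variables_initializer()
    ckpt_file = os.path.join(ckpt_dir, "latest.json")  # hypothetical checkpoint file
    # Resume: if a checkpoint exists, overwrite the fresh values with the saved ones
    if os.path.exists(ckpt_file):
        with open(ckpt_file) as f:
            state = json.load(f)
    for _ in range(steps):
        state["w"] += 0.5              # stand-in for one optimizer update
        state["global_step"] += 1
        if state["global_step"] % save_every == 0:
            with open(ckpt_file, "w") as f:
                json.dump(state, f)    # like saver.save(..., global_step=...)
    return state


if __name__ == "__main__":
    d = tempfile.mkdtemp()
    print(train(d, 4))  # {'global_step': 4, 'w': 2.0}
    print(train(d, 4))  # {'global_step': 8, 'w': 4.0}  (continues after the "restart")
```

As in mnist_backward.py, a restarted run still executes its full loop count, but the step counter continues from the saved value; anything trained after the last save is lost if the run is interrupted.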
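(1) tf.train.get_checkpoint_state(checkpoint_dir, latest_filename=None)
If the checkpoint directory contains a valid checkpoint state file, this function returns the state recorded in it (a CheckpointState object whose model_checkpoint_path field points to the latest saved model); otherwise it returns None.
Parameters:
checkpoint_dir: the directory where the checkpoint files are stored
latest_filename=None: optional name of the checkpoint state file; defaults to "checkpoint"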
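The state file that get_checkpoint_state reads is a small text file that Saver.save keeps up to date; its `model_checkpoint_path` line names the newest checkpoint. The sketch below parses that line by hand to show what the lookup amounts to. `latest_checkpoint_path` is a hypothetical helper, not part of TensorFlow; the real function returns a CheckpointState object and handles more cases. The file names in the example (mnist_model-49001 and so on) are made up.

```python
import os
import re
import tempfile


def latest_checkpoint_path(checkpoint_dir, latest_filename="checkpoint"):
    """Hypothetical, simplified stand-in for tf.train.get_checkpoint_state.

    Returns the path recorded on the model_checkpoint_path line of the
    state file, or None if the file (or that line) is missing.
    """
    state_file = os.path.join(checkpoint_dir, latest_filename)
    if not os.path.exists(state_file):
        return None  # nothing has been saved yet
    with open(state_file) as f:
        for line in f:
            # The newest checkpoint is on the line: model_checkpoint_path: "..."
            m = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if m:
                return m.group(1)
    return None


if __name__ == "__main__":
    d = tempfile.mkdtemp()
    print(latest_checkpoint_path(d))  # None
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "mnist_model-49001"\n'
                'all_model_checkpoint_paths: "mnist_model-48001"\n'
                'all_model_checkpoint_paths: "mnist_model-49001"\n')
    print(latest_checkpoint_path(d))  # mnist_model-49001
```

The `all_model_checkpoint_paths` lines list the older checkpoints that are still kept; only the first line decides what gets restored.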
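(2) saver.restore(sess, ckpt.model_checkpoint_path)
This function restores the saved variables into the current session, assigning the values stored in the checkpoint to the model parameters w and b.
Parameters:
sess: the current session; the previously saved values are loaded into it
ckpt.model_checkpoint_path: the path of the saved model. You do not have to supply the model's file name yourself: get_checkpoint_state has already read the checkpoint file to find out which saved model is the latest and what it is called.

Feeding real images and outputting predictions: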
mnist_forward.py, mnist_backward.py and mnist_test.py stay unchanged; add a new file, mnist_app.py.
The model expects white digits (255) on a black background (0), but the input images have black digits on a white background, so each pixel value is replaced by 255 minus itself to invert the image.

```python
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_forward
import mnist_backward


def restore_model(testPicArr):
    # Create a new graph and make it the default, so each call builds the model from scratch
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y = mnist_forward.forward(x, None)
        preValue = tf.argmax(y, 1)

        # Load the moving-average (shadow) values of the variables instead of the raw ones
        ema = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        with tf.Session() as sess:
            # Locate the latest saved model through the checkpoint file
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                preValue_val = sess.run(preValue, feed_dict={x: testPicArr})
                return preValue_val
            else:
                print('No checkpoint file found.')
                return -1


# Preprocess an input image
def pre_pic(picName):
    img = Image.open(picName)
    # Resize with anti-aliasing (Image.LANCZOS is the filter the old Image.ANTIALIAS
    # alias referred to; ANTIALIAS was removed in Pillow 10)
    reIm = img.resize((28, 28), Image.LANCZOS)
    # Convert to grayscale
    im_arr = np.array(reIm.convert('L'))
    # Binarize with a threshold to filter out noise (tune the threshold while debugging)
    threshold = 50
    for i in range(28):
        for j in range(28):
            im_arr[i][j] = 255 - im_arr[i][j]  # invert: white digit on black
            if im_arr[i][j] < threshold:
                im_arr[i][j] = 0
            else:
                im_arr[i][j] = 255
    nm_arr = im_arr.reshape([1, 784])
    nm_arr = nm_arr.astype(np.float32)
    im_ready = np.multiply(nm_arr, 1.0 / 255.0)  # scale to [0, 1]
    return im_ready


def application():
    testNum = input('input the number of test pictures:')
    for i in range(int(testNum)):
        testPic = input('input the path of test picture:')
        # Preprocess the handwritten digit image
        testPicArr = pre_pic(testPic)
        # Feed the preprocessed image to the restored model and output the prediction
        preValue = restore_model(testPicArr)
        print('The prediction number is: ', preValue)


if __name__ == '__main__':
    application()
```
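The per-pixel double loop in pre_pic (invert, threshold at 50, flatten, scale) can also be written with NumPy array operations. This is a sketch: `binarize_and_normalize` is a hypothetical helper name, and it assumes its input is the 28 x 28 uint8 array that pre_pic gets after resizing and `convert('L')`.

```python
import numpy as np


def binarize_and_normalize(gray, threshold=50):
    """Vectorized version of the pixel loop in pre_pic (hypothetical helper).

    gray: 28x28 uint8 array, black digit on a white background.
    Returns a float32 array of shape (1, 784) with values 0.0 or 1.0.
    """
    # Invert so the digit becomes white on a black background
    inverted = 255 - gray
    # Same rule as the loop: below the threshold -> 0 (background/noise), else -> 255
    binary = np.where(inverted < threshold, 0, 255)
    # Flatten to one row of 784 inputs and scale to [0, 1]
    return (binary.reshape(1, 784) / 255.0).astype(np.float32)
```

For a 28 x 28 image the loop is fast enough either way; the vectorized form mainly makes the threshold rule easier to read and adjust. A light-gray speck (say 220) becomes 35 after inversion, falls below the threshold and is cleared, which is the noise filtering the comment in pre_pic refers to.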
If the line `with tf.Graph().as_default() as g:` is left out, the first picture is recognized but the second one fails with this error (repetitive TensorFlow-internal lines shortened to "..."):

```
input the number of test pictures:10
input the path of test picture:pic/0.png
2018-07-19 10:19:56.079652: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
...
2018-07-19 10:19:56.317483: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
The prediction number is:  [0]
input the path of test picture:pic/1.png
2018-07-19 10:20:12.682054: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce 940MX, pci bus id: 0000:01:00.0, compute capability: 5.0)
2018-07-19 10:20:12.689655: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_7/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.690771: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_4/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.691076: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_5/ExponentialMovingAverage not found in checkpoint
2018-07-19 10:20:12.691219: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Variable_6/ExponentialMovingAverage not found in checkpoint
Traceback (most recent call last):
  ...
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
  ...
During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "mnist_application.py", line 58, in <module>
    application()
  File "mnist_application.py", line 53, in application
    preValue = restore_model(testPicArr)
  File "mnist_application.py", line 21, in restore_model
    saver.restore(sess, ckpt.model_checkpoint_path)
  ...
tensorflow.python.framework.errors_impl.NotFoundError: Key Variable_7/ExponentialMovingAverage not found in checkpoint
  ...
Caused by op 'save_1/RestoreV2_7', defined at:
  ...
  File "mnist_application.py", line 14, in restore_model
    saver = tf.train.Saver(ema_restore)
  ...
NotFoundError (see above for traceback): Key Variable_7/ExponentialMovingAverage not found in checkpoint
```

Without a new graph per call, the second call to restore_model adds another copy of the network to the same default graph, so TensorFlow gives its variables unique names (Variable_4 to Variable_7, and the saver becomes save_1). The checkpoint only contains the shadow values saved under the original names, so restoring those keys fails. Creating a fresh graph inside restore_model gives every call the original variable names.
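The log shows the second call looking for Variable_4 to Variable_7. The pure-Python toy model below illustrates why. `ToyGraph`, `build_forward` and `ema_keys` are hypothetical names invented for this sketch; TensorFlow's real name handling lives inside its Graph class and is more involved. It assumes, as the names in the log imply, that the forward pass creates four unnamed `tf.Variable` objects, and that each shadow value is stored under `<variable name>/ExponentialMovingAverage`.

```python
from collections import Counter


class ToyGraph:
    """Toy model (hypothetical) of how a TF1 graph makes default names unique."""

    def __init__(self):
        self._seen = Counter()

    def unique_name(self, base):
        count = self._seen[base]
        self._seen[base] += 1
        # First use keeps the bare name; later uses get _1, _2, ... appended
        return base if count == 0 else "%s_%d" % (base, count)


def build_forward(graph):
    # Assumption: the forward pass creates four unnamed variables
    return [graph.unique_name("Variable") for _ in range(4)]


def ema_keys(var_names):
    # Checkpoint keys under which the moving-average (shadow) values are looked up
    return [name + "/ExponentialMovingAverage" for name in var_names]


# Keys written by training, which built the network once in its own graph
saved = set(ema_keys(build_forward(ToyGraph())))

# Without a new graph per call, both pictures share one default graph
shared = ToyGraph()
first = build_forward(shared)   # ['Variable', 'Variable_1', 'Variable_2', 'Variable_3']
second = build_forward(shared)  # ['Variable_4', 'Variable_5', 'Variable_6', 'Variable_7']
missing = [k for k in ema_keys(second) if k not in saved]
print(missing)
```

A fresh `ToyGraph()` per call, the analogue of `with tf.Graph().as_default()`, restarts the numbering at `Variable`, so every key the saver asks for exists in the checkpoint.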