本节解决的问题
1.如何对输入的真实图片,输出预测结果?
2.如何制作数据集,实现特定应用?
答:
1.可以利用第五节中的网络结构,再添加部分代码(mnist_app.py), 实现输入手写数字图片输出识别结果。
2.利用tf中的 tfrecords文件。
6.1输入手写数字图片输出识别结果
6.1.1断点续训
断点续训可以解决神经网络训练被中断后,恢复神经网络训练时可以按之前的训练结果继续训练。实现方式是在反向传播中加入ckpt,代码如下
# 代码在第五节的反向传播代码基础上更改,省略了部分内容
def backward(mnist):
# 省略反向传播网络结构
# 省略滑动平均等配置
# 在Session()中加入ckpt实现断点续训练。
# 实例化类
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
# 恢复保存的训练参数,实现断点续训
saver.restore(sess, ckpt.model_checkpoint_path)
for i in range(STEPS):
# 省略循环训练和保存训练结果过程。
6.1.1图片预处理
首先,需要将图片预处理为全连接网络输入数据的格式。我们通过下面代码中的 pre_pic()函数实现。mnist数据集(本课所用的数字识别数据集),数据是黑底白字,黑底用0表示,白字用0~1之间浮点数表示,越接近1 颜色越白(这里老师说的应该是视觉直观上)。( 《TensorFlow实战Google学习框架》 书中描述,书中是接近1为黑色。和http://yann.lecun.com/exdb/mnist/ 网页上是说像素是0到255 255为黑色。应该是说图片数据中的1为有值。反色是因为在我们平时写的数字是白底黑字.总之要知道为什么在代码里反色。)
然后在restore_mode()函数中使用mnist_forward.py中定义的y, 并喂入处理之后的图片,模型预测的概率。
mnist_app.py 代码如下
#mnist_app.py 代码
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_backward
import mnist_forward
def restore_model(testPicArr):
# tf.Graph().as_default() 应该是为了处理多个程序同时调用
# mnist_forward.forward() 的情况,
# 屏蔽with tf.Graph().as_default() as tg:
# 一个窗口运行mnist_app,在另一个窗口同时运行mnist_backward,app会报错
# https://tensorflow.google.cn/api_docs/python/tf/Graph?hl=en
# 根据tf官网
# The default graph is a property of the current thread. If you create a new thread,
# and wish to use the default graph in that thread, you must explicitly add a with
# g.as_default(): in that thread's function.
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, None)
preValue = tf.argmax(y, 1)
variable_averages= tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
# variables_to_restore()
variable_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# 恢复网络结构
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict={x:testPicArr})
return preValue
else:
print 'No checkpoint file found!'
return -1
# 预处理, 包括resize, 转变灰度图,二值化。
def pre_pic(picName):
img = Image.open(picName)
reIm = img.resize((28, 28), Image.ANTIALIAS)
im_arr = np.array(reIm.convert('L'))
threshold = 50
# im_arr[i][j] only have 0 or 255???
for i in range(28):
for j in range(28):
# 反色处理,reb 和mnist 对黑的定义相反
# reb中 0为黑,mnist 中 1为黑
im_arr[i][j] = 255- im_arr[i][j]
if im_arr[i][j] < threshold :
im_arr[i][j] =0
else:
im_arr[i][j] =255
nm_arr = im_arr.reshape([1, 784])
nm_arr = nm_arr.astype(np.float32)
img_ready = np.multiply(nm_arr, 1.0/255.0)
return img_ready
def application():
testNum = int(input('input the number of test pictures:'))
for i range(testNum):
testPic = raw_input('the path of test picture:')
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print 'The prediction number is:', preValue
def main():
application()
if __name__ == '__main__':
main()
输出截图如下:
需要注意的是程序对mnist数据集外的图片识别率不是很高。
6.2 制作数据集
6.1.1tfrecords文件
tfrecords是一种二进制文件, 可先将图片和标签制作成为该格式的文件,使用tf.records进行数据读取,提高内存利用率。
用tf.train.Example的协议存储训练数据。训练数据的特征用键值对的形式表示。
如: ’img_raw' :值 ’label' :值 值是Byteslist/FloatList/Int64List
用SerializeToString() 把数据序列化成字符串存储。
伪代码如下,注意伪代码中把图片像素值除以255了 和mnist数据集是0到1之间的数对应
# 生成tfrecords 文件
writer = tf.python_io.TFRecordWriter(tfRecordName) # 新建一个writer
for 循环遍历每张图和标签 :
example = tf.train.Example(feature=tf.train.Features(feature={
'img_raw':tf.train.Feature(bytes_list=tf.train.Byteslist(value=[img_raw])),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
})) # 把每张图片和标签封装到example中
wirter.write(example.SerializeToString()) # 把example 进行序列化
# 解析tfrecords文件
fliename_queue =tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader() # 新建一个reader
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features={
'img_raw':tf.FixedLenFeature([],tf.string),
'label':tf.FixedLenFeature([10],tf.int64)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img.set_shape([784])
# 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应
img = tf.cast(img, tf.float32)*(1/255)
label = tf.cast(features['label'], tf.float32)
本节课代码和第五节课的区别为
数据集生成部分的代码:
# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import os
#tf.disable_v2_behavior()
image_train_path = './mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train = './data/mnist_train.tfrecords'
image_test_path = './mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path = './mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test = './data/mnist_test.tfrecords'
data_path = './data'
resize_height = 28
resize_width =28
# 生成tfrecords文件
def write_tfRecord(tfRecordName, image_path, label_path):
# 新建一个writer
writer = tf.python_io.TFRecordWriter(tfRecordName)
num_pic = 0
f = open(label_path, 'r')
contents =f.readlines()
f.close()
# 循环遍历每张图和标签
for content in contents:
value = content.split()
img_path =image_path + value[0]
img = Image.open(img_path)
img_raw = img.tobytes()
labels = [0] * 10
labels[int(value[1])] = 1
# 把图片和标签封装到example
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
# example 序列化
writer.write(example.SerializeToString())
num_pic += 1
print("the number of picture:", num_pic)
writer.close()
print("write tfrecord successful")
def generate_tfRecord():
isExists = os.path.exists(data_path)
if not isExists:
os.makedirs(data_path)
print('the directory was created successfully')
else:
print('directory already exists')
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
# 解析tfrecords文件
def read_tfRecord(tfRecord_path):
# 函数生成一个先入先出的队列,文件阅读器会使用它来读取数据
filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
# 新建一个reader
reader = tf.TFRecordReader()
# 解序列化,标签和图片的键名应该和制作tfrecords的键名相同,其中标签给出几分类
_, serialized_example =reader.read(filename_queue)
# 将tf.train.Example协议内存块(protocol buffer)解析为张量
features = tf.parse_single_example(serialized_example,
features={
'label':tf.FixedLenFeature([10], tf.int64),
'img_raw':tf.FixedLenFeature([],tf.string)
})
# 将 img_raw 字符串转化为8位无符号整形
img = tf.decode_raw(features['img_raw'],tf.uint8)
img.set_shape([784])
# 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应
img = tf.cast(img, tf.float32) * (1. / 255)
label = tf.cast(features['label'], tf.float32)
return img, label
def get_tfrecord(num,isTrain=True):
if isTrain:
tfRecord_path = tfRecord_train
else:
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
#随机读取一个batch的数据
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=num,
num_threads=2,
capacity=1000,
min_after_dequeue=700)
# 返回的图片和标签为随机抽取的batch_size组
return img_batch, label_batch
def main():
generate_tfRecord()
if __name__ == "__main__":
main()