《tensorflow笔记》学习记录第六节

本节解决的问题

1.如何对输入的真实图片,输出预测结果?
2.如何制作数据集,实现特定应用?

答:

1.可以利用第五节中的网络结构,再添加部分代码(mnist_app.py), 实现输入手写数字图片输出识别结果。

2.利用tf中的 tfrecords文件。

6.1输入手写数字图片输出识别结果

6.1.1断点续训

断点续训可以解决神经网络训练被中断后,恢复神经网络训练时可以按之前的训练结果继续训练。实现方式是在反向传播中加入ckpt,代码如下

# 代码在第五节的反向传播代码基础上更改,省略了部分内容

def backward(mnist):


    # 省略反向传播网络结构
    # 省略滑动平均等配置
    # 在Session()中加入ckpt实现断点续训练。
	
    # 实例化类
	saver = tf.train.Saver()

	with tf.Session() as sess:
		init_op = tf.global_variables_initializer()
		sess.run(init_op)

	    ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
	    if ckpt and ckpt.model_checkpoint_path:
            # 恢复保存的训练参数,实现断点续训 
		    saver.restore(sess, ckpt.model_checkpoint_path)
		
		for i in range(STEPS):
			# 省略循环训练和保存训练结果过程。

6.1.1图片预处理

首先,需要将图片预处理为全连接网络输入数据的格式。我们通过下面代码中的 pre_pic()函数实现。mnist数据集(本课所用的数字识别数据集),数据是黑底白字,黑底用0表示,白字用0~1之间浮点数表示,越接近1 颜色越白(这里老师说的应该是视觉直观上。( 《TensorFlow实战Google学习框架》 书中描述,书中是接近1为黑色。和http://yann.lecun.com/exdb/mnist/  网页上是说像素是0到255 255为黑色。应该是说图片数据中的1为有值。反色是因为在我们平时写的数字是白底黑字.总之要知道为什么在代码里反色。)

然后在restore_mode()函数中使用mnist_forward.py中定义的y, 并喂入处理之后的图片,模型预测的概率。

mnist_app.py 代码如下

#mnist_app.py 代码

import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_backward
import mnist_forward


def restore_model(testPicArr):
	# tf.Graph().as_default() 应该是为了处理多个程序同时调用
    # mnist_forward.forward() 的情况,
    # 屏蔽with tf.Graph().as_default() as tg:
    # 一个窗口运行mnist_app,在另一个窗口同时运行mnist_backward,app会报错
    # https://tensorflow.google.cn/api_docs/python/tf/Graph?hl=en
    # 根据tf官网 
    # The default graph is a property of the current thread. If you create a new thread, 
    # and wish to use the default graph in that thread, you must explicitly add a with 
    # g.as_default(): in that thread's function.

	with tf.Graph().as_default() as tg:
		x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])

		y = mnist_forward.forward(x, None)

		preValue = tf.argmax(y, 1)


		variable_averages= tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
		# variables_to_restore()
		variable_to_restore = variable_averages.variables_to_restore()
		saver = tf.train.Saver(variables_to_restore)
        
        # 恢复网络结构
		with tf.Session() as sess:
			ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
		if ckpt and ckpt.model_checkpoint_path:
			saver.restore(sess, ckpt.model_checkpoint_path)
			
			preValue = sess.run(preValue, feed_dict={x:testPicArr})
			return preValue
		else:
			print 'No checkpoint file found!'
			return -1

# 预处理, 包括resize, 转变灰度图,二值化。
def pre_pic(picName):
	img = Image.open(picName)
	reIm = img.resize((28, 28), Image.ANTIALIAS)
	im_arr = np.array(reIm.convert('L'))

	threshold = 50
	# im_arr[i][j] only have 0 or 255???

	for i in range(28):
		for j in range(28):
    # 反色处理,reb 和mnist 对黑的定义相反
    # reb中 0为黑,mnist 中 1为黑
			im_arr[i][j] = 255- im_arr[i][j]
			if im_arr[i][j] < threshold :
				im_arr[i][j] =0
			else:
				im_arr[i][j] =255

	nm_arr = im_arr.reshape([1, 784])
	nm_arr = nm_arr.astype(np.float32)

	img_ready = np.multiply(nm_arr, 1.0/255.0)

	return img_ready		


def application():
	testNum = int(input('input the number of test pictures:'))
	for i range(testNum):
		testPic = raw_input('the path of test picture:')
		
		testPicArr = pre_pic(testPic)
		
		preValue = restore_model(testPicArr)
		print 'The prediction number is:', preValue


def main():
	application()


if __name__ == '__main__':
	main()

输出截图如下:

 需要注意的是程序对mnist数据集外的图片识别率不是很高。

6.2 制作数据集

6.1.1tfrecords文件

tfrecords是一种二进制文件, 可先将图片和标签制作成为该格式的文件,使用tf.records进行数据读取,提高内存利用率。

用tf.train.Example的协议存储训练数据。训练数据的特征用键值对的形式表示。

如: ’img_raw' :值      ’label' :值     值是Byteslist/FloatList/Int64List

用SerializeToString() 把数据序列化成字符串存储。

伪代码如下,注意伪代码中把图片像素值除以255了 和mnist数据集是0到1之间的数对应

# 生成tfrecords 文件
writer = tf.python_io.TFRecordWriter(tfRecordName) # 新建一个writer
for 循环遍历每张图和标签 :
    example = tf.train.Example(feature=tf.train.Features(feature={
        'img_raw':tf.train.Feature(bytes_list=tf.train.Byteslist(value=[img_raw])),
        'label':tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
    }))  # 把每张图片和标签封装到example中
    wirter.write(example.SerializeToString()) # 把example 进行序列化

# 解析tfrecords文件
fliename_queue =tf.train.string_input_producer([tfRecord_path])
reader = tf.TFRecordReader() # 新建一个reader
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features={
    'img_raw':tf.FixedLenFeature([],tf.string),
    'label':tf.FixedLenFeature([10],tf.int64)
})
img = tf.decode_raw(features['img_raw'], tf.uint8)
img.set_shape([784])
# 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应
img = tf.cast(img, tf.float32)*(1/255) 
label = tf.cast(features['label'], tf.float32)

本节课代码和第五节课的区别为

数据集生成部分的代码:

# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import  os 



#tf.disable_v2_behavior()
image_train_path = './mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train = './data/mnist_train.tfrecords'
image_test_path = './mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path = './mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test = './data/mnist_test.tfrecords'
data_path = './data'
resize_height = 28
resize_width =28


# 生成tfrecords文件
def write_tfRecord(tfRecordName, image_path, label_path):
    # 新建一个writer
    writer = tf.python_io.TFRecordWriter(tfRecordName)
    num_pic = 0
    f = open(label_path, 'r')
    contents =f.readlines()
    f.close()
    # 循环遍历每张图和标签
    for content in contents:
        value = content.split()
        img_path =image_path + value[0]
        img = Image.open(img_path)
        img_raw = img.tobytes()
        labels = [0] * 10
        labels[int(value[1])] = 1
        # 把图片和标签封装到example
        example = tf.train.Example(features=tf.train.Features(feature={
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'label':   tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
        }))
        # example 序列化
        writer.write(example.SerializeToString())
        num_pic += 1
        print("the number of picture:", num_pic)

        writer.close()
        print("write tfrecord successful")

def generate_tfRecord():
    isExists = os.path.exists(data_path)
    if not isExists:
        os.makedirs(data_path)
        print('the directory was created successfully')
    else:
        print('directory already exists')
        write_tfRecord(tfRecord_train, image_train_path, label_train_path)
        write_tfRecord(tfRecord_test, image_test_path, label_test_path)


# 解析tfrecords文件
def read_tfRecord(tfRecord_path):
    # 函数生成一个先入先出的队列,文件阅读器会使用它来读取数据
    filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
    # 新建一个reader
    reader = tf.TFRecordReader()
    # 解序列化,标签和图片的键名应该和制作tfrecords的键名相同,其中标签给出几分类
    _, serialized_example =reader.read(filename_queue)
    # 将tf.train.Example协议内存块(protocol buffer)解析为张量
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label':tf.FixedLenFeature([10], tf.int64),
                                           'img_raw':tf.FixedLenFeature([],tf.string)
                                       })
    # 将 img_raw 字符串转化为8位无符号整形
    img = tf.decode_raw(features['img_raw'],tf.uint8)
    img.set_shape([784])
    # 注意这里把图片像素值除以255了 和mnist数据集是0到1之间的数对应
    img = tf.cast(img, tf.float32) * (1. / 255)
    
    label = tf.cast(features['label'], tf.float32)

    return  img, label


def get_tfrecord(num,isTrain=True):
    if isTrain:
        tfRecord_path = tfRecord_train
    else:
        tfRecord_path = tfRecord_test
    img, label = read_tfRecord(tfRecord_path)
    #随机读取一个batch的数据
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                 batch_size=num,
                                                 num_threads=2,
                                                 capacity=1000,
                                                 min_after_dequeue=700)
    # 返回的图片和标签为随机抽取的batch_size组
    return  img_batch, label_batch


def main():
    generate_tfRecord()


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值