TensorFlow学习笔记04

TFRecord格式

TFRecord格式可以统一不同的原始数据格式,并更加有效地管理不同的属性

格式介绍

其数据通过tf.train.Example Protocol Buffer格式存储:

message Example{
	Feature features = 1;
	};
message Features{
	map<string, Feature> feature = 1;
	};
message Feature{
	oneof kind{
		BytesList bytes_list = 1;
		FloatList float_list = 2;
		Int64List int64_list = 3;
		}
	;}

字典:属性名称string——>取值Feature
属性名称:字符串
属性取值:字符串、实数列表、整数列表

样例程序

  • TFRecord写操作:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


mnist = input_data.read_data_sets("", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples

filename = ""
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    # 图像矩阵转化为一个字符串
    image_raw = images[index].toString()
    example = tf.train.Example(features=tf.train.Feature(feature={
        'pixels': _int64_feature(pixels),
        'labels': _int64_feature(np.argmax(labels[index])),
        'image_raw': _bytes_feature(image_raw)
    }))
    writer.write(example.SerializeToString())
writer.close()

  • TFRecord读操作:
import tensorflow as tf

reader = tf.TFRecordReader()
# 队列维护输入文件列表
filename_queue = tf.train.string_input_producer([""])
# 从文件中读取一个样例
_, serialized_example = reader.read(filename_queue)
# 解析读入的一个样例
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw': tf.FixedLenFeature([], tf.string),
        'pixels': tf.FixedLenFeature([], tf.int64),
        'labels': tf.FixedLenFeature([], tf.int64),
    }
)
# 将字符串解析为对应的像素数组
image = tf.decode_raw(features['image_raw'], tf.uint8)
lable = tf.cast(features['labels'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
# 启动多线程处理数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(10):
    print(sess.run([image, lable, pixels]))

多线程输入数据处理框架

经典的输入数据处理流程:

Created with Raphaël 2.2.0 开始 指定原始数据的文件列表 创建文件列表队列 从文件中读取数据 数据预处理 整理成Batch作为神经网络输入 结束

其中 队列 是Tensorflow多线程输入数据处理框架的基础。

队列

队列与变量类似,修改队列状态的操作主要有Enqueue、EnqueueMany、Dequeue。
TensorFlow中提供两种队列:FIFOQueue(先进先出队列)和RandomShuffleQueue(元素打乱队列,每次从队列中随机选择一个出队,在神经网络中希望每次使用的训练数据尽量随机,所以常用此队列)
操作队列例子:

import tensorflow as tf

# 先进先出队列创建(队列中最多可保存两个元素)
q = tf.FIFOQueue(2, "int32")
# 队列初始化(注意逗号不要忘记,这样既是列表,又满足tf接收的是tensor张量的需求)
init = q.enqueue_many(([0, 10],))
# 出队
x = q.dequeue()
y = x + 1
# 入队
q_inc = q.enqueue([y])

with tf.Session() as sess:
    init.run()
    for _ in range(5):
        v, _ = sess.run([x, q_inc])
        print(v)

队列不仅是一种数据结构,还是异步计算张量的一个重要机制。多线程可同时向一个队列中写元素或读元素。

多线程

TensorFlow提供两个类完成多线程协同功能:tf.Coordinator(协同多个线程一起停止)和tf.QueueRunner(启动多个线程来操作同一队列)

import tensorflow as tf
import numpy as np
import threading
import time


def MyLoop(coord, work_id):
    while not coord.should_stop():
       if np.random.rand() < 0.1:
            print("stop from id: %d" % work_id)
            coord.request_stop()
       else:
            print("working on id: %d" % work_id)
       time.sleep(1)


# 声明类
coord = tf.train.Coordinator()
# 创建5个线程(将coord类传入每个线程中)
threads = [
    threading.Thread(target=MyLoop, args=(coord, i, )) for i in range(5)
]
# 启动所有线程
for t in threads:
    t.start()
# 等待所有线程退出
coord.join(threads)

综合TFRecord和队列——队列管理

tf.train.match_filenames_once函数:获取符合正则表达式的所有文件
tf.train.string_input_producer函数:对文件列表进行有效管理
注意:当一个输入队列中的所有文件都被处理完后,会将初始化时提供的文件列表中的文件全部重新加入队列。num_epochs参数限制加载初始文件列表的最大轮数。

import tensorflow as tf

files = tf.train.match_filenames_once("./data.tfrecords-*")
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'i': tf.FixedLenFeature([], tf.int64),
        'j': tf.FixedLenFeature([], tf.int64)
    })

with tf.Session() as sess:
    # 使用tf.train.match_filename_once()函数时需要初始化一些变量
    tf.local_variables_initializer().run()
    # 启动线程
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(6):
        print(sess.run([features['i'], features['j']]))
    # 停止所有线程
    coord.request_stop()
    coord.join(threads)

注意:TensorFlow提供了一套更高层的数据处理框架——“数据集”,tf.data核心组件。
并正式推荐使用数据集作为输入数据的首选框架。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值