tensorflow 数据读取

TensorFlow常见的数据读取方式分为3类:

1. placeholder+feeding

2. QueueRunner

3. dataset

其中第三种方式是目前的主流,但前2种方法,尤其是第二种方法我认为了解一下对于理解他人的代码是有帮助的:

 

1. placeholder+feeding

简单地说就是用Python程序处理好相关的数据,然后将整理好的数据通过placeholder+feed_dict的方式直接将数据喂给模型,这种方法感觉比较慢,适合小模型,小数据:

def data_in_fn():
    return np.array(xxx)


with tf.Session() as sess:
    sess.run(train_op, feed_dict={in_:data_in_fn()})

2. QueueRunner

这个方案是基于队列的形式构建的输入输入通道,简单地说就是构建一个文件名队列,一个内存队列,如下图所示,文件名队列中存储所有的文件名数据,如果设置了epoch,shuffle等信息,则会体现在文件名队列中,比如有3个文件,名字为A,B,C,当epoch=2,shuffle=True时,文件名队列可能的顺序为CBAACB这样,内存队列始终根据FIFO的原则,从队列的头部读取下一个文件的名字,然后放到内存中。

构建的步骤如下:

首先,文件列表就是需要处理的文件的名字,可以使用Python的list,然后使用接口  tf.train.string_input_producer   将该列表创建成一个FIFO的队列,此接口可以设置是否要shuffle文件名列表,设置想要多少个epoch,但要注意,shuffle只是调整一个epoch的文件顺序,比如文件有A,B,C三个,想要2个epoch,那么shuffle之后不可能出现ABBCCA这种情况情况,因为对于第一个epoch中,出现了2个B。

第二,TensorFlow给出了几个我们常用的格式的reader和decode,分别处理:text(CSV)格式,二进制文件(读取固定长度字节数信息),tfrecord格式的文件。三组接口分别是:

filename_queue = ['文件1', '文件2']

# 处理text和CSV  read 方法每执行一次,会从文件中读取一行。然后 decode_csv 将读取的内容解析成一个Tensor列表
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
# 处理二进制
record_bytes = label_bytes + image_bytes

# 读取固定长度字节数信息(针对bin文件使用FixedLengthRecordReader读取比较合适)
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
 
# 字符串转为张量数据decode_raw
record_bytes = tf.decode_raw(value, tf.uint8)
 
# tf.stride_slice(data, begin, end):从张量中提取数据段,并用cast进行数据类型转换
label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

‘’’
首先说一下这个文件是怎么来的,首先将数据放到EXAMPLE protocol buffer中,然后TF会将其进行序列化,然后调用writer接口,将数据写下来即可。
‘’‘
# 处理tfrecord格式的文件
filename_queue = tf.train.string_input_producer([filename], num_epochs=3, num_epochs=num_epochs)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
image, label = decode(serialized_example)

# 其中的decode是一个Python函数,用于处理每一个serialized_example,比如
def decode():
    features = tf.parse_single_example(example_proto, features={
        'feature1': tf.VarLenFeature(tf.float32),
        'feature1_shape': tf.FixedLenFeature([1], tf.int64),
        'label': tf.FixedLenFeature([1], tf.int64),
    })
    .... # 其他处理
    return features, labels

第三,预处理。预处理是独立于模型的,主要是对数据做一些相关的处理,比如加噪、畸变等。这一步骤可省略。

第四,在训练之前,需要调用接口  tf.train.start_queue_runners  才能启动填充队列的线程,具体这里参见博客

 

                           图2 queue的示意图(摘自博客https://blog.csdn.net/u014061630/article/details/80712635

我们训练网络的时候都要设置一个batch有多少个数据,参见下面的代码,其中capacity表示再做shuffle的时候,是从capacity中随机选择batch_size个数据。如果不想要shuffle,则使用接口tf.train.batch。

def read_my_file_format(filename_queue):
  reader = tf.SomeReader()  # reader部分
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)  # decoder部分
  processed_example = some_processing(example)   # 其他预处理
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)

  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
    
  # 或者使用  tf.train.batch  不对数据进行shuffle
  # example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size)
  return example_batch, label_batch

最后,一般来说,使用  tf.train.Coordinator  配合来管理线程。

该方案总体流程代码如下:

# Create a session for running operations in the Graph.
sess = tf.Session()

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

 

3. dataset

使用dataset,关键就是明确2个概念:如何生成dataset?如何使用取dataset元素的iterator。

tf.data.Dataset:当中的每一个元素就是训练数据的一个数据,这个数据当然是可以包含多个tensor的。

tf.data.Iterator:使用接口get_next(),每次从dataset中取出一个元素。

3.1 如何生成tf.data.Dataset:

1. 如果数据在内存中使用接口tf.data.Dataset.from_tensors() 或 tf.data.Dataset.from_tensor_slices()。如果数据在tfrecord文件中,使用接口tf.data.TFRecordDataset()。得到的每个元素有2个属性,一个是数据类型,一个是形状shape,跟别可以使用Dataset.output_types 和 Dataset.output_shapes 来查看相应的属性:

a = tf.reshape(tf.range(100), shape=(10, 2, 5))
dataset1 = tf.data.Dataset.from_tensors(a)
print(dataset1.output_types, dataset1.output_shapes)
# <dtype: 'int32'> (10, 2, 5)

dataset2 = tf.data.Dataset.from_tensor_slices(a)
print(dataset2.output_types, dataset2.output_shapes)
# <dtype: 'int32'> (2, 5)

b = tf.reshape(tf.range(120), shape=(10, 2, 6))
dataset3 = tf.data.Dataset.from_tensor_slices({'a': a, 'b': b})
print(dataset3.output_types, dataset3.output_shapes)
# {'a': tf.int32, 'b': tf.int32} {'a': TensorShape([Dimension(2), Dimension(5)]), 'b': TensorShape([Dimension(2), Dimension(6)])}

dataset4 = tf.data.Dataset.from_tensor_slices((a, b))
print(dataset4.output_types, dataset4.output_shapes)
#  (tf.int32, tf.int32) (TensorShape([Dimension(2), Dimension(5)]), TensorShape([Dimension(2), Dimension(6)]))

可以看到,tf.data.Dataset.from_tensor_slices可以接受元组,list,字典等。

如果是从tfrecord中读取数据:

dataset = tf.data.TFRecordDataset(record_name)

得到的dataset需要进一步解析

这里需要插一段具体如何读/写tfrecord格式的文件,上面粗略的说了步骤:将数据放到Example中,然后TF会有序列化的工具(protocol buffer)帮你进行序列化,序列化了以后就可以调用write写下来了(序列化的意思就是说将字典、类对象等结构化的数据转成字符串等序列化的数据,只有这样,才能进行数据的保存和传输)。

3.1.1 .tfrecord格式的文件

参考博客

1. 这个方案光看名字也知道该格式的文件能够跟TF配合的比较好,对于训练模型使用大型训练集该方案目前是最好的。TF提供了丰富的API,能够帮我们轻松读写该文件。

前面说了,把需要存储的数据组织到Example中,那么怎么组织?Example中有什么?见下面的代码:

message Example {
 Features features = 1;
};
 
message Features{
 map<string,Feature> featrue = 1;
};
 
message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

Example这个消息体里面包含了一系列的属性Features(具体这个消息体怎么去理解,可以当做一个类似结构体的东西,这个主要是谷歌的protocol buffer协议中定义的,我也没有深入了解细节,只要知道怎么根据这个消息体设置内容即可,另外,就是TF的图运算等都是跟这个协议相关的:Python代码构建一个运算图,TF会根据这个协议将这个图序列化,然后后端的c/c++代码才能运算这个图)。

那么Features属性又是什么?可以看到Features消息体中的内容是一个map,key是字符串,value是Feature,而最后一步Feature内部可以包含3种数据类型的数据:byte型(存储图像、音频、字符串内容等)、int型、float型。所以,tfrecord文件只能保存3种类型的数据。其实为了更丰富的保存数据,还有FeatureList和FeatureLists消息体,但他们其实就是消息体中的内容是map还是list的区别,具体他们之间的关系可以如下图:

code表示:

def fun():
    # 注意:这里的字符'1'等如果前面不加那个b,会报错
    bl = tf.train.BytesList(value=[b'1', b'2', b'3'])
    fl = tf.train.FloatList(value=[1.0, 2.0, 3.0])
    il = tf.train.Int64List(value=[1, 2, 3, 4])

    f_bl = tf.train.Feature(bytes_list=bl)
    f_il = tf.train.Feature(int64_list=il)
    f_fl = tf.train.Feature(float_list=fl)

    k_fs = tf.train.Features(feature={'first': f_bl,
                                      'second': f_il,
                                      'three': f_fl})

    l_fs = tf.train.FeatureList(feature=[f_bl, f_il, f_fl])
    k_l_fs = tf.train.FeatureLists(feature_list={'key1': l_fs})

    # 将数据放到example中
    data1 = tf.train.Example(features=k_fs)
    # 这里附加了一个SequenceExample,其实跟Example大同小异,只不过它内部多了一个存储tf.train.FeatureLists数据的参数,我平时用的不多
    data2 = tf.train.SequenceExample(context=k_fs, feature_lists=k_l_fs)

    # 写文件
    writer = tf.python_io.TFRecordWriter('1.tfrecords')
    # 这里使用SerializeToString接口进行序列化
    writer.write(data1.SerializeToString())
    writer.close()

    writer = tf.python_io.TFRecordWriter('2.tfrecords')
    writer.write(data2.SerializeToString())
    writer.close()

这样就生成了1.tfrecords和2.tfrecords文件。下面看一下如何读取:前面说了,读取时候调用接口dataset = tf.data.TFRecordDataset(record_name)即可读取,但这不是重点,重点是读取的dataset,如何解析?简单的说就是当时存的时候都设置了内部某个字段,如‘first’这个字段,存了什么内容,这个内容是什么数据类型的,那么读取的时候也是要按照当时的key字段恢复出来。

我们在存储数据的时候,比如存的是一个NLP中一个句子中的每个字在字典中的索引,这是int类型,但是每个句子的长度都是不一样的,这个时候解析的时候要用到一个接口tf.VarLenFeature(),如果存储的是分类模型的label,比如分为3个类,label的值只能取0,1,2。并且每个句子的label就一项,那么使用的接口是tf.FixedLenFeature,其实就是当解析固定大小的数据时使用该接口。

当使用tf.VarLenFeature()时,数据长度未知,所以只要指定解析的数据类型即可,使用该接口返回的结果是稀疏结果,需要调用接口tf.sparse_tensor_to_dense使它变成dense类型。使用tf.FixedLenFeature时,既然是固定长度,那么就要指定数据类型和长度(保存时都变成了1维数组)。

所以,我们解析上面的1.tfrecords的文件结果的代码:

def fun1():
    def _parse_fun(example_proto):
        # 这里是针对每一个example都做如下的解析
        features = tf.parse_single_example(example_proto, features={
            'first': tf.FixedLenFeature([3], tf.string),
            'second': tf.VarLenFeature(tf.int64),
            'three': tf.FixedLenFeature([3], tf.float32),
        })
        second = tf.cast(tf.sparse_tensor_to_dense(features['second']), tf.int32)
        three = features['three']
        first = tf.cast(features['first'], tf.string)
        return first, second, three

    record_name = '1.tfrecords'
    dataset = tf.data.TFRecordDataset(record_name)
    dataset = dataset.map(_parse_fun)

好了,到目前为止我们已经知道怎么生成tf.data.Dataset了:读取tfrecord文件或者直接通过Python数据生成

3.1.2 其他生成dataset的手段:

可以读取二进制文件  tf.data.FixedLengthRecordDataset  构建;

基于  tf.data.TextLineDataset  构建,主要是读取TXT或者CSV格式的文件;

 

3.2 tf.data.Iterator

迭代器按照这篇博客介绍的,有四种方案,但目前来说,感觉最常见的就是第一种方案:单次迭代器。所以此处只记录该迭代器的使用,其他迭代器的使用请参考上面的博客。

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)
  assert i == value

当dataset内的元素取完的时候,如果再次在继续取数据,就会报一个OutOfRangeError的错误。

 

3.3 dataset的其他操作

3.3.1 dataset.map(fun)操作

通过函数fun,将dataset的每个元素进行一定的运算,比如上面读取了tfrecord格式的文件之后,对每一个example进行一个解析。

还有一种就是下面这种也是比较常见的:

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])

# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

当使用Python的外部函数作为map解析的fun时,需要使用接口tf.py_func.

# tf.py_func
tf.py_func(func, # 一个Python函数
           inp, # 一个Tensor列表
           Tout, # 输出的Tensor的dtype或Tensors的dtype列表
           stateful=True, # 布尔值,输入值相同,输出值就相同,那么就将stateful设置为False
           name=None)
import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
    lambda filename, label: tuple(tf.py_func(
        _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

3.3.2 dataset.batch()操作

其实就是根据batch_size设置一个batch的数据,操作很简单,dataset.batch(batch_size),设置batch_size即可。

但训练时经常出现不同的数据长度不一致,比如NLP领域,这个时候需要将短的数据进行padding,那么使用dataset.padded_batch()即可,这个接口相对上面复杂一些,需要指定padding的shape和dtype。其中dtype好说,希望补充的数据类型是什么就是什么,但shape需要注意:

#tf.TensorShape([])     表示长度为单个数字
#tf.TensorShape([None]) 表示长度未知的向量
padded_shapes=(tf.TensorShape([None]), tf.TensorShape([]))

比方说分类数据中的label,它只能取一个int数值,padding的shape就是[],表示单个数据,而对于一个音频的特征,一条音频的特征是个二维数组,padding的shape就是[None, None]。

3.3.3 dataset.shuffle()和repeat操作

dataset.repeat(epoch).shuffle(buffer_size=batch_size*2)

 

 

参考文献:

文中的一些代码和部分图都是从以下这些博客中摘取的:

https://blog.csdn.net/u014061630/article/details/80712635

https://zhuanlan.zhihu.com/p/27238630

https://blog.csdn.net/tefuirnever/article/details/90523253

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值