Tensorflow读取数据1

原文地址: http://blog.csdn.net/u010911921/article/details/70577697

这段一直在用Tensorflow来做深度学习上的相关工作,然后对Tensorflow读取数据的方式进行实现。特地总结一下。首先是读取二进制图片数据,这里采用的是CIFAR-10的二进制数据

## 1.CIFAR-10数据集 CIFAR-10数据集合是包含60000张`32*32*3`的图片,其中每个类包含6000张图片,总共10类。在这60000张图片中50000张是训练集合,10000张是测试集合。

其中二进制的图片保存的格式如下所示:

2.Tensorflow读取数据

从Tensorflow的官网可以看到从文件中读取数据的流程主要是一下步骤:

  1. The list of filenames
  2. (Optional) filename shuffling
  3. (Optional) epoch limit
  4. Filename queue
  5. A Reader for the file format
  6. A decoder for a record read by the reader
  7. (Optional) preprocessing
  8. Example queue

按照这样一个流程,首选应该将CIFAR-10的训练集和测试集合,生成文件名列表,然后在讲这个文件名列表传递给tf.train.string_input_producer函数创建一个用于保存文件名称的FIFO的队列,最后用tensor flow产生的reader从队列中读取数据。当reader读到数据就需要用tf.decode_raw函数对读取到的二进制数进行解码。

结束了上述操作,下面就需要采用另一个queue去batch together examples来为训练和测试提供数据。采用tf.train.shuffle_batch将上面生成的imagelabel传入函数即可完成。

3.开始训练

tf.train.shuffle_batch生成batch以后就开始利用tf.train.start_queue_runners函数启动队列,然后开始整个计算图,官网给的建议是如下形式:

init_op = tf.global_variables_initializer()
with tf.Session as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess= sess,coord = coord)
    try:
        while not coord.should_stop():
            #run training steps or whatever
            sess.run(train_op)
    except tf.errors.OutOfRangeError:
        print('Done training --epoch limit reached')
    finally:
        # when done,ask the threads to stop
        coord.request_stop()
    coord.join(threads)

4.代码实现

在神经网络的训练中由于每训练k步以后就会对网络进行一次测试,所以需要在上述步骤中,增加动态选择文件名称队列这样一个过程,可以由tf.QueueBase.from_list函数进行实现,然后reader从返回的文件名称队列中读取数据。

整个过程的实现如下所示:

#!/usr/bin/env python3
# --*-- encoding:utf-8 --*--

import tensorflow as tf
import numpy as np
import os

def read_cifar10(data_dir,is_traing,batch_size,shuffle):
    """

    :param data_dir:数据保存路径
    :param is_traing:True从训练集获取数据,False从测试集获取数据
    :param batch_size:  batch_size的大小
    :param shuffle: bool,是否进行shuffle操作
    :return:
    """
    img_width = 32
    img_height = 32
    img_depth = 3
    label_bytes = 1
    img_bytes = img_height * img_width *img_depth



    with tf.name_scope("input") as scope:
        #训练集合的文件列表
        train_filenames = [os.path.join(data_dir,
                                        'data_batch_%d.bin'%ii) for ii in np.arange(1,6)]
        #测试集合的文件列表
        val_filenames = [os.path.join(data_dir,'test_batch.bin')]

        #训练集和测试集合的文件名称队列
        train_queue = tf.train.string_input_producer(train_filenames)
        val_queue = tf.train.string_input_producer(val_filenames)

        #挑选文件队列,实现training的过程中测试
        queue_select = tf.cond(is_traing,
                               lambda :tf.constant(0),
                               lambda :tf.constant(1) )
        queue = tf.QueueBase.from_list(queue_select,[train_queue,val_queue])

        #从队列中读取固定长度的数据
        reader = tf.FixedLengthRecordReader(label_bytes+img_bytes)
        key,value = reader.read(queue)
        recode_bytes = tf.decode_raw(value,tf.uint8)

        #获取label
        label = tf.slice(recode_bytes,[0],[label_bytes])
        label = tf.cast(label,tf.int32)

        #获取image
        image_raw = tf.slice(recode_bytes,[label_bytes],[img_bytes])
        image_raw = tf.reshape(image_raw,[img_depth, img_height, img_width])
        image = tf.transpose(image_raw,[1,2,0])

        image = tf.cast(image,tf.float32)

        #对每一张图片进行标准化操作,可选操作此处可以进行对图片的各种操作
        image = tf.image.per_image_standardization(image)

        if shuffle:
            images, label_batch= tf.train.shuffle_batch([image,label],
                                                   batch_size=batch_size,
                                                   num_threads=16,
                                                   capacity=512+3*batch_size,
                                                   min_after_dequeue=512,
                                                   allow_smaller_final_batch=True)
        else:
            images, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            num_threads=16,
                                            capacity=512 + 3*batch_size,
                                            allow_smaller_final_batch=True)
        label_batch = tf.cast(label_batch,tf.int32)

        return images,label_batch

整个过程是采用VGG-16的网络模型进行训练的,在迭代16000次,tensorboard展示的结果如图所示:

code下载地址https://github.com/ZhichengHuang/LearnTensorflowCode

参考资料:

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
TensorFlow是一个开源的机器学习框架,通常用于创建神经网络模型。在训练模型之前,需要准备好数据集,本文将介绍如何使用TensorFlow读取数据TensorFlow提供了多种读取数据的方法,其中最常用的是使用tf.data模块。首先,我们需要定义一个数据集对象,并通过读取文件的方式将数据加载进来。TensorFlow支持多种文件格式,如csv、txt、json、tfrecord等,可以根据自己的需求选择合适的格式。 加载数据后,我们可以对数据进行一些预处理,比如做数据增强、进行归一化等操作。预处理完数据后,我们需要将数据转化为张量类型,并将其打包成batch。通过这种方式,我们可以在每次训练中同时处理多个数据。 随后,我们可以使用tf.data.Dataset中的shuffle()函数打乱数据集顺序,防止模型只学习到特定顺序下的模式,然后使用batch()函数将数据划分成批次。最后,我们可以使用repeat()函数让数据集每次可以被使用多次,达到更好的效果。 在TensorFlow中,我们可以通过输入函数将数据集传入模型中,使模型能够直接从数据集中读取数据。使用输入函数还有一个好处,即能够在模型训练时动态地修改数据的内容,特别是在使用esimator模块进行模型训练时,输入函数是必须要的。 总结一下,在TensorFlow读取数据的流程如下:定义数据集对象-读取文件-预处理数据-打包数据为batch-打乱数据集-划分批次数据-重复数据集-使用输入函数读取数据。 在实际应用过程中,我们还可以通过其他方式来读取数据,如使用numpy、pandas等工具库,也可以自定义数据集类来处理数据。无论使用何种方式,读取数据都是机器学习训练中重要的一步,需要仔细处理。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值