利用Tensorflow构建自己的图片数据集TFrecords

相信很多初学者和我一样,虽然用了那么久的tensorflow,也尝试了很多的实例,但那些实例基本都是直接利用官方文档现成的MNIST和cifar_10数据库,而一旦需要自己构建数据集时,完全不知道该如何制作并输入自己改的数据。另外,虽然也有一些人提供了相关的操作,但是总是或多或少存在各种各样的问题。今天给大家分享我的Tensorflow制作数据集的学习历程。 TensorFlow提供了标准的TFRecord 格式,而关于 tensorflow 读取数据, 官网也提供了3中方法 :
1 Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据
2 Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中
3 在声明tf.variable变量或numpy数组时保存数据。受限于内存大小,适用于数据较小的情况

特此声明:初次写博客,如有问题,如有问题多体谅;另外文本参考了下面的博客(提供链接如下),因而读者可结合两者取齐所需。

点击打开链接http://blog.csdn.net/miaomiaoyuan/article/details/56865361

在本文,主要介绍第二种方法,利用tf.record标准接口来读入文件

第一步,准备数据

先在网上下载一些不同类的图片集,例如猫、狗等,也可以是同一种类,不同类型的,例如哈士奇、吉娃娃等都属于狗类;此处笔者预先下载了哈士奇、吉娃娃两种狗的照片各20张,并分别将其放置在不同文件夹下。如下:


第二步,制作TFRecord文件

注意:tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签 如在本例中,只有0,1 两类

#-----------------------------------------------------------------------------
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = 'E:/train_data/picture_dog//' 
classes = {'husky','jiwawa'}


#制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #将图片转化为原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()
#-------------------------------------------------------------------------

将上面的代码编辑完成后,点击运行,就会生成一个dog_train.TFRecords文件,如下图所示:


TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

第三步,读取TFRecord文件

#-------------------------------------------------------------------------
cwd = 'E:/train_data/picture_dog//' 
#读取二进制数据

def read_and_decode(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label
#--------------------------------------------------------------------------    
一个Example中包含Features,Features里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List。另外,需要我们注意的是:feature的属性“label”和“img_raw”名称要和制作时统一 ,返回的img数据和label数据一一对应。

第四步,TFRecord的显示操作

如果想要检查分类是否有误,或者在之后的网络训练过程中可以监视,输出图片,来观察分类等操作的结果,那么我们就可以session回话中,将tfrecord的图片从流中读取出来,再保存。因而自然少不了主程序的存在。

#---------主程序----------------------------------------------------------
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dog_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess: #开始一个会话  
        sess.run(init_op)  
        coord=tf.train.Coordinator()  
        threads= tf.train.start_queue_runners(coord=coord)  
        for i in range(40):  
            example, lab = sess.run(batch)#在会话中取出image和label  
            img=Image.fromarray(example, 'RGB')#这里Image是之前提到的  
            img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下图片;注意cwd后边加上‘/’  
            print(example, lab)  
        coord.request_stop()  
        coord.join(threads) 
        sess.close()
#-----------------------------------------------------------------------------
进过上面的一通操作之后,我们便可以得到和tensorflow官方的二进制数据集一样的数据集了,并且可以按照自己的设计来进行。

下面附上该程序的完整代码,仅供参考。

#-----------------------------------------------------------------------------
#encoding=utf-8
import os
import tensorflow as tf
from PIL import Image

cwd = 'E:/train_data/picture_dog//' 
classes = {'husky','jiwawa'}


#制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter("dog_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = cwd +"/"+ name+"/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64, 64))
            img_raw = img.tobytes() #将图片转化为原生bytes
            print (index,img_raw)
            example = tf.train.Example(
               features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
               }))
            writer.write(example.SerializeToString())
    writer.close()
#-------------------------------------------------------------------------

#读取二进制数据

def read_and_decode(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64, 64, 3])
    #img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(label, tf.int32)
    return img, label
#--------------------------------------------------------------------------    
#---------主程序----------------------------------------------------------
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dog_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    
    with tf.Session() as sess: #开始一个会话  
        sess.run(init_op)  
        coord=tf.train.Coordinator()  
        threads= tf.train.start_queue_runners(coord=coord)  
        for i in range(40):  
            example, lab = sess.run(batch)#在会话中取出image和label  
            img=Image.fromarray(example, 'RGB')#这里Image是之前提到的  
            img.save(cwd+'/'+str(i)+'_Label_'+str(lab)+'.jpg')#存下图片;注意cwd后边加上‘/’  
            print(example, lab)  
        coord.request_stop()  
        coord.join(threads) 
        sess.close()
#-----------------------------------------------------------------------------
运行上述的完整代码,便可以 将从TFRecord中取出的文件保存下来了。如下图:


每一幅图片的命名中,第二个数字则是 label,吉娃娃都为1,哈士奇都为0;通过对照图片,可以发现图片分类正确。

  • 5
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 14
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值