生成tfrecod类型数据集

        最近在做tensorflow分布式训练时,遇到一个问题,就是在分布式文件系统中,tensorflow读取jpeg数据很慢,因为有十几万的图片,导致要读半个小时以上,所以想提高数据读取速度,就把jpeg数据转换成tfrecord类型数据。我已celeba数据为例,二十多万张图片,转换成tfrecord类型的数据后,读取这些数据只要30秒左右。

        其实在tensorflow的models的代码仓库里,有好多模型是把数据集转换为转换为tfrecord的,例如ImageNet 里的restnet_v2_50,cafir10等,有兴趣的可以去找了看看,地址:https://github.com/tensorflow/models

        下面我介绍两种生成tfrecord数据的方法,一种生成的要比原始数据大10倍左右,另一种和原始数据差不多,但是数据格式不一样,可以根据自己的需要选择。以转换celeba数据为例。

        第一种,使用open-cv读取数据,这种生成的要比原始数据大10倍左右。

使用的writer为:

                writer = tf.python_io.TFRecordWriter(train_filename)

可以选择是否加压缩,压缩有亮总可选GZIP和GLIB:

                options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)

                writer = tf.python_io.TFRecordWriter(train_filename, options=options)

使用cv2读取图片数据:

            def load_image(addr):
                    # read an image and resize to (224, 224)
                    # cv2 load images as BGR, convert it to RGB
                    img = cv2.imread(addr)
                    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = img.astype(np.float32)

                    return img

获取图片数据和lable:

                 # Load the image
                img = load_image(addrs_train[i])

                label = label_list[i]

创建一个feature,feature有三种类型,int64_list,bytes_list,float_list三种:

                def _int64_feature(value):
                      return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
                def _bytes_feature(value):
                      return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

                feature = {'train/label': _int64_feature(label),

                       'train/image': _bytes_feature(tf.compat.as_bytes(img.tostring()))}

创建example protocol buffer,写文件,写完后关闭writer:

                example = tf.train.Example(features=tf.train.Features(feature=feature))

                writer.write(example.SerializeToString())

                writer.close()

第二种就是使用tensorflow里的函数tf.gfile.FastGFile,来读取数据。

步骤和上面基本一样,在读取图片数据有区别,

直接读取数据:

                image_data = tf.gfile.FastGFile(addrs_train[i], 'r').read()

获取图片的高和宽:

                height, width = image_reader.read_image_dims(sess, image_data)

其他步骤和上面类似,就不多写了。

这两种写数据的方式不同,在后面读取这两种数据也是有区别的,具体可以看源码,

代码地址:https://github.com/zzdgit/my_project/tree/master/Machine-Learning

里面的celeba.py 和celeba1.py就是tfrecord转换代码。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值