Tensorflow-tfrecord数据

使用的数据:https://download.pytorch.org/tutorial/hymenoptera_data.zip

1、图像—>tfrecode

#!/usr/bin/python3
# -*- coding: UTF-8 -*-

import tensorflow as tf
import glob
from itertools import groupby
from collections import defaultdict
from PIL import Image
import numpy as np

# 将满足目录的所有.jpg文件的路径放置在image_filenames列表中
# image_filenames存放所有满足条件的jpg的路径
image_filenames = glob.glob("./hymenoptera_data/*/*.jpg")  # ==> <class 'list'>

sess = tf.InteractiveSession()

training_dataset = defaultdict(list)
testing_dataset = defaultdict(list)

# Split up the filename into its breed and corresponding filename. The breed is found by taking the directory name
# 将文件名分解为品种和相应的文件名(文件对应的路径),品种对应文件夹名称(作为标签)
image_filename_with_breed = map(lambda filename: (filename.split("/")[-2], filename), image_filenames)  # Linux "/"

# Group each image by the breed which is the 0th element in the tuple returned above
for dog_breed, breed_images in groupby(image_filename_with_breed, lambda x: x[0]):
    # Enumerate each breed's image and send ~20% of the images to a testing set
    for i, breed_image in enumerate(breed_images):
        if i % 5 == 0:
            testing_dataset[dog_breed].append(breed_image[1])  # dog_breed对应文件名,即标签,breed_image[1]对应jpg的路径
        else:
            training_dataset[dog_breed].append(breed_image[1])

    # Check that each breed includes at least 18% of the images for testing
    breed_training_count = len(training_dataset[dog_breed])
    breed_testing_count = len(testing_dataset[dog_breed])

    assert round(breed_testing_count / (breed_training_count + breed_testing_count),
                 2) > 0.18, "Not enough testing images."


# 图像--->tfrecode
def write_records_file(dataset, record_location):
    """
    Fill a TFRecords file with the images found in `dataset` and include their category.

    Parameters
    ----------
    dataset : dict(list)
      Dictionary with each key being a label for the list of image filenames of its value.
    record_location : str
      Location to store the TFRecord output.
    """
    writer = None

    # Enumerating the dataset because the current index is used to breakup the files if they get over 100
    # images to avoid a slowdown in writing.
    # 枚举dataset,因为当前索引用于对文件进行划分,每隔100幅图像,训练样本的信息就被写入到一个新的Tfrecode文件中,以加快操作的进程
    current_index = 0
    for breed, images_filenames in dataset.items():
        for image_filename in images_filenames:
            if current_index % 10 == 0:
                if writer:
                    writer.close()

                record_filename = "{record_location}-{current_index}.tfrecords".format(
                    record_location=record_location,
                    current_index=current_index)

                writer = tf.python_io.TFRecordWriter(record_filename)
            current_index += 1

            '''
            # 方法一,使用PIL
            try:
                image=Image.open(image_filename)
                image=image.convert('L') #转成灰度图
                image=image.resize((250,151))
            except:
                print(image_filename)
                continue

            image_bytes = sess.run(tf.cast(np.array(image), tf.uint8)).tobytes()
            '''

            # 方法二、使用tf.image.decode_jpeg
            # 在ImageNet的狗的图像中,有少量无法被Tensorflow识别的JPEG的图像,利用try/catch可以将这些图像忽略
            try:
                image_file = tf.read_file(image_filename)
                image = tf.image.decode_jpeg(image_file)

                # 转换成灰度图可以减少处理的计算量和内存占用,但这不是必须的
                grayscale_image = tf.image.rgb_to_grayscale(image)  # 转成灰度
                resized_image = tf.image.resize_images(grayscale_image, (250, 151))  # 图像大小固定为 250x151
# resized_image=tf.image.resize_image_with_crop_or_pad(grayscale_image,250,151)

                # 这里之所以使用tf.cast,是因为虽然尺寸更改后的图像数据类型是浮点型,但RGB尚未转换到[0,1)区间内
                image_bytes = sess.run(tf.cast(resized_image, tf.uint8)).tobytes()
            except:
                print(image_filename)
                continue

            # https://en.wikipedia.org/wiki/One-hot
            # 将标签按字符串存储较高效,推荐的做法是将其转换为整数索引或独热编码的秩1张量

            #
            '''
            image_label = tf.case({tf.equal(breed, tf.constant('n02085620-Chihuahua')): lambda: tf.constant(0),
                              tf.equal(breed, tf.constant('n02096051-Airedale')): lambda: tf.constant(1),
                              }, lambda: tf.constant(-1), exclusive=True)

            image_label = sess.run(image_label)
            '''
            image_label = breed.encode("utf-8")


            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                # 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[image_label])),
                'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
            }))

            writer.write(example.SerializeToString())
    if writer:
        writer.close()


if __name__ == "__main__":
    write_records_file(training_dataset, "./output/training-images/training-images")
    write_records_file(testing_dataset, "./output/testing-images/testing-images")

2、tfrecode—>numpy

import tensorflow as tf
import glob
# from itertools import groupby
# from collections import defaultdict
# from PIL import Image
# import numpy as np


# Load Images
def load_images_from_tfrecord(tfrecord_file):
    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(tfrecord_file)) # 加载多个Tfrecode文件
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            # 'label': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
        })

    record_image = tf.decode_raw(features['image'], tf.uint8)

    # Changing the image into this shape helps train and visualize the output by converting it to
    # be organized like an image.
    # 修改图像的形状有助于训练和输出的可视化
    image = tf.reshape(record_image, [250, 151, 1])

    label = tf.cast(features['label'], tf.string)
    # label = tf.cast(features['label'], tf.int64)


    # label string-->int 0,1 标签
    label = tf.case({tf.equal(label, tf.constant('n02085620-Chihuahua')): lambda: tf.constant(0),
                            tf.equal(label, tf.constant('n02096051-Airedale')): lambda: tf.constant(1),
                            }, lambda: tf.constant(-1), exclusive=True)


    min_after_dequeue = 10
    batch_size = 3
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue)


    '''
    # Find every directory name in the imagenet-dogs directory (n02085620-Chihuahua, ...)
    labels = list(map(lambda c: c.split("\\")[-2], glob.glob(imagepath))) # 找到目录名(标签) linux使用 "/"

    # Match every label from label_batch and return the index where they exist in the list of classes
    # 匹配每个来自label_batch的标签并返回它们在类别列表中的索引
    train_labels = tf.map_fn(lambda l: tf.where(tf.equal(labels, l))[0,0:1][0], label_batch, dtype=tf.int64)
    '''

    # Converting the images to a float of [0,1) to match the expected input to convolution2d
    # 将图像转换为灰度值位于[0,1)的浮点类型,
    float_image_batch = tf.image.convert_image_dtype(image_batch, tf.float32)
    return float_image_batch,label_batch




if __name__=="__main__":
    img_batch,label_batch=load_images_from_tfrecord("output/training-images/*.tfrecords")
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                for i in range(100):
                    val, l = sess.run([img_batch, label_batch])
                    if i%5==0:
                        print(val.shape, l.shape,l)
                else:
                    break
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值