使用自己的图片制作TFrecords

 

先上代码:

# -*-coding: utf-8 -*-

"""

    @Project: tf_record_demo

    @File   : tf_record_batchSize.py

    @Author : panjq

    @E-mail : pan_jinquan@163.com

    @Date   : 2018-07-27 17:19:54

    @desc   : 将图片数据保存为多个record文件

"""

##########################################################################


import tensorflow as tf

import numpy as np

import os

import cv2

import math

import matplotlib.pyplot as plt

import random

from PIL import Image


##########################################################################

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成字符串型的属性

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# 生成实数型的属性

def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def show_image(title, image):
    '''

    显示图片

    :param title: 图像标题

    :param image: 图像的数据

    :return:

    '''

    # plt.figure("show_image")

    # print(image.dtype)

    plt.imshow(image)

    plt.axis('on')  # 关掉坐标轴为 off

    plt.title(title)  # 图像题目

    plt.show()


def load_labels_file(filename, labels_num=1):
    '''

    载图txt文件,文件中每行为一个图片信息,且以空格隔开:图像路径 标签1 标签2,如:test_image/1.jpg 0 2

    :param filename:

    :param labels_num :labels个数

    :return:images type->list

    :return:labels type->list

    '''

    images = []

    labels = []

    with open(filename) as f:

        for lines in f.readlines():

            line = lines.rstrip().split(' ')

            label = []

            for i in range(labels_num):
                label.append(int(line[i + 1]))

            images.append(line[0])

            labels.append(label)

    return images, labels


def read_image(filename, resize_height, resize_width):
    '''

    读取图片数据,默认返回的是uint8,[0,255]

    :param filename:

    :param resize_height:

    :param resize_width:

    :return: 返回的图片数据是uint8,[0,255]

    '''

    bgr_image = cv2.imread(filename)

    if len(bgr_image.shape) == 2:  # 若是灰度图则转为三通道

        print("Warning:gray image", filename)

        bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR)

    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # 将BGR转为RGB

    # show_image(filename,rgb_image)

    # rgb_image=Image.open(filename)

    if resize_height > 0 and resize_width > 0:
        rgb_image = cv2.resize(rgb_image, (resize_width, resize_height))

    rgb_image = np.asanyarray(rgb_image)

    # show_image("src resize image",image)

    return rgb_image


def create_records(image_dir, file, record_txt_path, batchSize, resize_height, resize_width):
    '''

    实现将图像原始数据,label,长,宽等信息保存为record文件

    注意:读取的图像数据默认是uint8,再转为tf的字符串型BytesList保存,解析请需要根据需要转换类型

    :param image_dir:原始图像的目录

    :param file:输入保存图片信息的txt文件(image_dir+file构成图片的路径)

    :param output_record_txt_dir:保存record文件的路径

    :param batchSize: 每batchSize个图片保存一个*.tfrecords,避免单个文件过大

    :param resize_height:

    :param resize_width:

    PS:当resize_height或者resize_width=0是,不执行resize

    '''

    if os.path.exists(record_txt_path):
        os.remove(record_txt_path)

    setname, ext = record_txt_path.split('.')

    # 加载文件,仅获取一个label

    images_list, labels_list = load_labels_file(file, 1)

    sample_num = len(images_list)

    # 打乱样本的数据

    # random.shuffle(labels_list)

    batchNum = int(math.ceil(1.0 * sample_num / batchSize))

    for i in range(batchNum):

        start = i * batchSize

        end = min((i + 1) * batchSize, sample_num)

        batch_images = images_list[start:end]

        batch_labels = labels_list[start:end]

        # 逐个保存*.tfrecords文件

        filename = setname + '{0}.tfrecords'.format(i)

        print('save:%s' % (filename))

        writer = tf.python_io.TFRecordWriter(filename)

        for i, [image_name, labels] in enumerate(zip(batch_images, batch_labels)):

            image_path = os.path.join(image_dir, batch_images[i])

            if not os.path.exists(image_path):
                print('Err:no image', image_path)

                continue

            image = read_image(image_path, resize_height, resize_width)

            image_raw = image.tostring()

            print('image_path=%s,shape:( %d, %d, %d)' % (image_path, image.shape[0], image.shape[1], image.shape[2]),
                  'labels:', labels)

            # 这里仅保存一个label,多label适当增加"'label': _int64_feature(label)"项

            label = labels[0]

            example = tf.train.Example(features=tf.train.Features(feature={

                'image_raw': _bytes_feature(image_raw),

                'height': _int64_feature(image.shape[0]),

                'width': _int64_feature(image.shape[1]),

                'depth': _int64_feature(image.shape[2]),

                'label': _int64_feature(label)

            }))

            writer.write(example.SerializeToString())

        writer.close()

        # 用txt保存*.tfrecords文件列表

        # record_list='{}.txt'.format(setname)

        with open(record_txt_path, 'a') as f:

            f.write(filename + '\n')


def read_records(filename, resize_height, resize_width):
    '''

    解析record文件

    :param filename:保存*.tfrecords文件的txt文件路径

    :return:

    '''

    # 读取txt中所有*.tfrecords文件

    with open(filename, 'r') as f:
        lines = f.readlines()

        files_list = []

        for line in lines:
            files_list.append(line.rstrip())

    # 创建文件队列,不限读取的数量

    filename_queue = tf.train.string_input_producer(files_list, shuffle=False)

    # create a reader from file queue

    reader = tf.TFRecordReader()

    # reader从文件队列中读入一个序列化的样本

    _, serialized_example = reader.read(filename_queue)

    # get feature from serialized example

    # 解析符号化的样本

    features = tf.parse_single_example(

        serialized_example,

        features={

            'image_raw': tf.FixedLenFeature([], tf.string),

            'height': tf.FixedLenFeature([], tf.int64),

            'width': tf.FixedLenFeature([], tf.int64),

            'depth': tf.FixedLenFeature([], tf.int64),

            'label': tf.FixedLenFeature([], tf.int64)

        }

    )

    tf_image = tf.decode_raw(features['image_raw'], tf.uint8)  # 获得图像原始的数据

    tf_height = features['height']

    tf_width = features['width']

    tf_depth = features['depth']

    tf_label = tf.cast(features['label'], tf.int32)

    # tf_image=tf.reshape(tf_image, [-1])    # 转换为行向量

    tf_image = tf.reshape(tf_image, [resize_height, resize_width, 3])  # 设置图像的维度

    # 存储的图像类型为uint8,这里需要将类型转为tf.float32

    # tf_image = tf.cast(tf_image, tf.float32)

    # [1]若需要归一化请使用:

    tf_image = tf.image.convert_image_dtype(tf_image, tf.float32)  # 归一化

    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255)  # 归一化

    # [2]若需要归一化,且中心化,假设均值为0.5,请使用:

    # tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) - 0.5 #中心化

    return tf_image, tf_height, tf_width, tf_depth, tf_label


def disp_records(record_file, resize_height, resize_width, show_nums=4):
    '''

    解析record文件,并显示show_nums张图片,主要用于验证生成record文件是否成功

    :param tfrecord_file: record文件路径

    :param resize_height:

    :param resize_width:

    :param show_nums: 默认显示前四张照片



    :return:

    '''

    tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file, resize_height, resize_width)  # 读取函数

    # 显示前show_nums个图片

    init_op = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()

        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        for i in range(show_nums):
            image, height, width, depth, label = sess.run(
                [tf_image, tf_height, tf_width, tf_depth, tf_label])  # 在会话中取出image和label

            # image = tf_image.eval()

            # 直接从record解析的image是一个向量,需要reshape显示

            # image = image.reshape([height,width,depth])

            print('shape:', image.shape, 'label:', label)

            # pilimg = Image.fromarray(np.asarray(image_eval_reshape))

            # pilimg.show()

            show_image("image:%d" % (label), image)

        coord.request_stop()

        coord.join(threads)


def batch_test(record_file, resize_height, resize_width):
    '''

    :param record_file: record文件路径

    :param resize_height:

    :param resize_width:

    :return:

    :PS:image_batch, label_batch一般作为网络的输入

    '''

    tf_image, tf_height, tf_width, tf_depth, tf_label = read_records(record_file, resize_height, resize_width)  # 读取函数

    # 使用shuffle_batch可以随机打乱输入:

    # shuffle_batch用法:https://blog.csdn.net/u013555719/article/details/77679964

    min_after_dequeue = 100  # 该值越大,数据越乱,必须小于capacity

    batch_size = 4

    # capacity = (min_after_dequeue + (num_threads + a small safety margin∗batchsize)

    capacity = min_after_dequeue + 3 * batch_size  # 容量:一个整数,队列中的最大的元素数

    image_batch, label_batch = tf.train.shuffle_batch([tf_image, tf_label],

                                                      batch_size=batch_size,

                                                      capacity=capacity,

                                                      min_after_dequeue=min_after_dequeue)

    init = tf.global_variables_initializer()

    with tf.Session() as sess:  # 开始一个会话

        sess.run(init)

        coord = tf.train.Coordinator()

        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(4):
            # 在会话中取出images和labels

            images, labels = sess.run([image_batch, label_batch])

            # 这里仅显示每个batch里第一张图片

            show_image("image", images[0, :, :, :])

            print(images.shape, labels)

        # 停止所有线程

        coord.request_stop()

        coord.join(threads)


if __name__ == '__main__':
    # 参数设置

    image_dir = r'F:\DATA\test data'

    train_file = r'F:\DATA\test data\train.txt'  # 图片路径

    output_record_txt = r'F:\DATA\test data\tfrecords.txt'  # 指定保存record的文件列表

    resize_height = 224  # 指定存储图片高度

    resize_width = 224  # 指定存储图片宽度

    batchSize = 8000  # batchSize一般设置为8000,即每batchSize张照片保存为一个record文件

    # 产生record文件

    create_records(image_dir=image_dir,

                   file=train_file,

                   record_txt_path=output_record_txt,

                   batchSize=batchSize,

                   resize_height=resize_height,

                   resize_width=resize_width)

    # 测试显示函数

    disp_records(output_record_txt, resize_height, resize_width)

    # batch_test(output_record_txt,resize_height, resize_width)

 效果图:

生成的tfrecords文件: 

 

参考博客https://blog.csdn.net/guyuealian/article/details/80857228

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值