分类任务3——把弄好的数据制作成tfrecord

既然都用了tensorflow,那干脆数据文件也弄成这个格式算了。
(我绝对不会说是因为直接读取图像太慢了)

没错又是这个

"""

# 把图像数据制作成tfrecord

"""

import tensorflow as tf
import os
from PIL import Image
import random
from tqdm import tqdm


def _int64_feature(label):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))


def _bytes_feature(imgdir):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[imgdir]))


def float_list_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def get_example_nums(tf_records_filenames):
    nums = 0
    for record in tf.python_io.tf_record_iterator(tf_records_filenames):
        nums += 1
    return nums


def load_file(imagestxtdir, shuffle=False):
    images = []  # 存储各个集中图像地址的列表
    labels = []
    with open(imagestxtdir) as f:
        lines_list = f.readlines()  # 读取文件列表中所有的行
        if shuffle:
            random.shuffle(lines_list)
        for line in lines_list:
            line_list = line.rstrip().split(' ')  # rstrip函数删除指定字符,这里用的rstrip()因为括号内是空格,所以是删除空白部分
            label = []

            for i in range(1):
                label.append(int(line_list[i + 1]))
            # 这里本质就是要line_list[1],因为这个部分就是存label的,可以用下面一行直接替代
            # label.append(int(line_list[1]))

            # cur_img_dir=images_base_dir+'/'+line_list[0]
            images.append(line_list[0])
            labels.append(label)
    return images, labels


def create_tf_records(image_base_dir, image_txt_dir, tfrecords_dir,
                      resize_height, resize_width, log, shuffle):
    images_list, labels_list = load_file(image_txt_dir, shuffle)

    # 判断是否存在保存tfrecord文件的路径,如果没有,就创建一个。
    tf_dir, tf_name = os.path.split(tfrecords_dir)
    if not os.path.exists(tf_dir):
        os.makedirs(tf_dir)
    tfrecords_dir = tf_dir + '/' + tf_name
    # print(tfrecords_dir)

    writer = tf.python_io.TFRecordWriter(tfrecords_dir)
    # print('len is :', len(images_list))
    # image_name 这个函数虽然没有用到,但是作用仍十分关键。因为后面的zip要求有两个变量。
    print('\n#######################start to create %s###########################' % tf_name)
    for i, [image_name, single_label_list] in enumerate(zip(images_list, labels_list)):

        cur_image_dir = image_base_dir + '/' + images_list[i]

        if not os.path.exists(cur_image_dir):
            print('the image path is not exists')
            continue

        image = Image.open(cur_image_dir)
        image = image.resize((resize_height, resize_width))
        image_raw = image.tobytes()
        single_label = single_label_list[0]

        if i % log == 0 or i == len(images_list) - 1:
            print('------------processing:%d-th------------' % i)
        example = tf.train.Example(features=tf.train.Features(feature={
            'image_raw': _bytes_feature(image_raw),
            'label': _int64_feature(single_label)}))
        writer.write(example.SerializeToString())
    print('#######################successfully create %s###########################\n' % tf_name)
    writer.close()


if __name__ == '__main__':
    resize_height = 600
    resize_width = 600
    # shuffle = True
    log = 5

    train_image_dir = 'E:/111project/ship image/train'
    train_txt_dir = 'E:/111project/ship image/train.txt'
    train_records_dir = 'E:/111project/tfrecordss/train.tfrecords'
    create_tf_records(train_image_dir, train_txt_dir, train_records_dir,
                      resize_height, resize_width, log, shuffle=True)
    train_nums = get_example_nums(train_records_dir)
    print('the train records number is:', train_nums)

    validation_image_dir = 'E:/111project/ship image/validation'
    validation_txt_dir = 'E:/111project/ship image/validation.txt'
    validation_records_dir = 'E:/111project/tfrecordss/validation.tfrecords'
    create_tf_records(validation_image_dir, validation_txt_dir, validation_records_dir,
                      resize_height, resize_width, log, shuffle=True)
    validation_nums = get_example_nums(validation_records_dir)
    print('the validation records number is:', validation_nums)

    test_image_dir = 'E:/111project/ship image/test'
    test_txt_dir = 'E:/111project/ship image/test.txt'
    test_records_dir = 'E:/111project/tfrecordss/test.tfrecords'
    create_tf_records(test_image_dir, test_txt_dir, test_records_dir,
                      resize_height, resize_width, log, shuffle=False)
    test_nums = get_example_nums(test_records_dir)
    print('the test records number is:', test_nums)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值