将自己的数据集制作成tf格式,文件批量重命名

主要参考博客:data_to_tf

在进行深度学习时,如何将自己的数据制作成tf格式是要关注的第一步。结合看到的多篇博客论文,并整理成python code格式进行记录及分享。

import tensorflow as tf
import os
from PIL import Image
import numpy as np

IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL = 128, 128, 1


# 生成整数型的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成字符串类型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# 制作TFRecord格式
def createTFRecord(filename, mapfile):
    '''
    :param filename: output path
    :param mapfile:
    :return:
    '''
    class_map = {}
    data_dir = '/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf'
    classes = {'/ori', '/new'}
    # 输出TFRecord文件的地址
    writer = tf.python_io.TFRecordWriter(filename)
    for index, name in enumerate(classes):
        class_path = data_dir + name + '/'
        class_map[index] = name
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # 每个图片的地址
            img = Image.open(img_path)
            #  获得当前原始图片的shape,当前的shape不能大于resize之后的大小
            #  print(np.shape(img))
            img = img.resize((IMAGE_HEIGHT, IMAGE_WIDTH))  # 进行resize
            img_raw = img.tobytes()  # 将图片转化成二进制格式
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(index),
                'image_raw': _bytes_feature(img_raw)
            }))
            # print('example', example)
            writer.write(example.SerializeToString())
    writer.close()
    txtfile = open(mapfile, 'w+')
    for key in class_map.keys():
        txtfile.writelines(str(key) + ":" + class_map[key] + "\n")
    txtfile.close()


# 读取生成的tfrecord,并进行resize
def read_and_decode(filename):
    # 创建一个reader来读取TFRecord文件中的样例
    reader = tf.TFRecordReader()
    # 创建一个队列来维护输入文件列表
    filename_queue = tf.train.string_input_producer([filename], shuffle=False, num_epochs=1)
    # 从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
    _, serialized_example = reader.read(filename_queue)
    #     print _,serialized_example

    # 解析读入的一个样例,如果需要解析多个,可以用parse_example
    features = tf.parse_single_example(
        serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64),
                  'image_raw': tf.FixedLenFeature([], tf.string), })
    # 将字符串解析成图像对应的像素数组
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_HEIGHT, IMAGE_CHANNEL])  # reshape为128*128*3通道图片
    img = tf.image.per_image_standardization(img)
    labels = tf.cast(features['label'], tf.int32)
    return img, labels


# 生成batch
def createBatch(filename, batchsize):
    images, labels = read_and_decode(filename)

    min_after_dequeue = 10
    capacity = min_after_dequeue + 3 * batchsize

    image_batch, label_batch = tf.train.shuffle_batch([images, labels],
                                                      batch_size=batchsize,
                                                      capacity=capacity,
                                                      min_after_dequeue=min_after_dequeue
                                                      )

    label_batch = tf.one_hot(label_batch, depth=2)
    return image_batch, label_batch


if __name__ == "__main__":
    # 训练图片两张为一个batch,进行训练,测试图片一起进行测试
    mapfile = '/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf/classmap.txt'
    train_filename = '/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf/train.tfrecords'
    createTFRecord(train_filename, mapfile)
    test_filename = '/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf/test.tfrecords'
    createTFRecord(test_filename, mapfile)
    image_batch, label_batch = createBatch(filename=train_filename, batchsize=2)
    test_images, test_labels = createBatch(filename=test_filename, batchsize=2)
    with tf.Session() as sess:
        initop = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(initop)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            step = 0
            while 1:
                _image_batch, _label_batch = sess.run([image_batch, label_batch])
                step += 1
                print(step)
                print(_label_batch)
        except tf.errors.OutOfRangeError:
            print(" trainData done!")

        try:
            step = 0
            while 1:
                _test_images, _test_labels = sess.run([test_images, test_labels])
                step += 1
                print(step)
                #                 print _image_batch.shape
                print(_test_labels)
        except tf.errors.OutOfRangeError:
            print(" TEST done!")

        coord.request_stop()
        coord.join(threads)
存在的注意点:

原始的数据中图片的格式要要保证通道数的一致,不然resize要出错。

另一个要解决的点就是进行图片重命名的问题。

直接上代码:

# 将原始路径下的图片进行重命名,并复制保存到新的路径下,最终返回新的路径下的文件名列表。
import os
import shutil


def rename(path_ori, newpath,flage=True):
    newname_front = input("please input the new name style:")
    print('new name is the format like %s_1.....'%(newname_front))
    newname_front = newname_front.strip()
    file_list = os.listdir(path_ori)
    i = 0
    for file in file_list:
        i += 1
        olddir = os.path.join(path_ori, file)
        if os.path.isdir(olddir):
            continue
        filename = os.path.splitext(file)[0]
        filetype = os.path.splitext(file)[1]
        newname = newname_front + '_' + str(i)
        rename_dir = os.path.join(path_ori, newname + filetype)
        rename_new_dir = os.path.join(newpath, newname + filetype)
        os.rename(olddir, rename_dir)
        # savedatapath = os.path.join(strangedatafile, filename)
        if flage:
            shutil.copyfile(rename_dir, rename_new_dir)
    newfile_list = os.listdir(newpath)
    return newfile_list


###########################//
#test
###########################//
path = '/home/sxf/MyProject_Python/ori_image/new'
newpath = '/home/sxf/MyProject_Python/ori_image/ori'
list = rename(path, newpath,flage=False)
print(list)


  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值