实例5:将图片文件制作成TFRecord数据集

实例5:将图片文件制作成TFRecord数据集

有两个文件夹,放置男人和女人的照片
要求

  • 将两个文件夹中的图片制成TFRecord格式的数据集
  • 从数据集中读取数据,并将得到的图片数据保存到本地文件

TFRecord格式与TensorFlow框架绑定,通用性较差
但它是一种非常高效的数据持久化方法,尤其对需要预处理的样本集
将处理后的数据用TFRecord格式保存训练,可以大大提高训练模型的运行效率

1. 样本介绍

  • 文件夹的名称可以当做样本标签
  • 文件夹中的具体图片文件可被当做具体的样本数据

2. 代码实现:读取样本文件的目录及标签

import os 
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()

from PIL import Image
from sklearn.utils import shuffle
import numpy as np
from tqdm import tqdm

def load_sample(sample_dir, shuffleflag = True):
    print("loading sample dataset..")
    lfilenames = []
    labelsnames = []
    for (dirpath, dirnames, filenames) in os.walk(sample_dir):
        for filename in filenames:
            filename_path = os.sep.join([dirpath, filename])
            lfilenames.append(filename_path)
            labelsnames.append(dirpath.split('\\')[-1])
    
    lab = list(sorted(set(labelsnames)))
    labdict = dict(zip(lab, list(range(len(lab)))))
    labels = [labdict[i] for i in labelsnames]
    
    if shuffleflag == True:
        return (shuffle(np.asarray(lfilenames), np.asarray(labels))), np.asarray(lab)
    return (np.asarray(lfilenames), np.asarray(labels)), np.asarray(lab)
    
directory = 'man_woman\\'
(filenames, labels), _ = load_sample(directory, shuffleflag=False)

引入第三方库tqdm,以便在批处理过程中显示进度

3. 代码实现:定义函数生成TFRecord

def makeTFRec(filenames,labels):
    writer = tf.python_io.TFRecorWriter("mydata.tfrecords")
    for i in tqdm(range(0,len(labels))):
        img = Image.open(filenames[i])
        img = img.resize((256,256))
        img_raw = img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
                                        "label":   
                                   tf.train.Feature(int64_list=tf.train.Int64List(value=[labels[i]]),
                                        'img_raw':  
                                    tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                                    }))
        writer.write(example.SerializeToString())  
        writer.close()

makeTFRec(filenames, labels)

TFRecordWriter类
tf.train.Feature()特征
tf.trian.Features()和tf.train.Example()

4. 代码实现:读取TFRecord数据集,并将其转化为队列

函数read_and_decode支持两种模式的队列格式转换:

  • 训练模式:对数据集进行乱序操作,并将其按照指定批次组合起来
  • 测试模式:按照顺序读取数据集一次
def read_and_decode(filenames, flag='train', batch_size=3):
    if flag == 'train':
        filename_queue = tf.train.string_input_producer(filenames)  #已经进行乱序读取了
    else:
        filename_queue = tf.trina.string_input_producer(filenames, nun_epochs=1, shuffle=False) #取一个批次,并且顺序
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features = {
                                           'label': tf.FixedLenFeature([],tf.int64),
                                           'img_raw': tf.FixEdLenFeature([],tf.string),
                                           })
    #tf.decode_raw可以将字符串解析成图像对应的像素数组
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [256,256,3])
    label = tf.cast(features['label'],tf.int32)
    
    if flag == 'train':
        image = tf.cast(image,tf.float32) * (1. / 255) - 0.5    #训练是将其归一化
        img_batch, label_batch = tf.train.batch([image, label], batch_size=batch_size, capacity=20)
        return img_batch, label_batch
    return image, label

TFRecordfilenames = ["mydata.tfrecord"]
image, label = read_and_decode(TFRecordfilenames, flag='test')

tf.TFRecordreader()

5. 代码实现:建立会话,将数据保存到文件

saveimgpath = 'show\\'

if tf.gfile.Exists(saveimgpath):     #如果存在saveimgpath,则将其删除
    tf.gfile.DeleteRecursively(saveimgpath)
tf.gfile.MakeDirs(saveimgpath)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    myset = set([])     #建立集合对象,用于存放子文件夹
    
    try:
        print("start")
        i = 0
        while True:
            example, examplelab = sess.run([image, label])
            print("2")
            examplelab = str(examplelab)
            if examplelab not in myset:
                myset.add(examplelab)
                tf.gfile.MakeDirs(saveimgpath+examplelab)
            img = Image.fromarray(example, 'RGB')   #转换成image格式
            img.save(saveimgpath+examplelab + '/' + str(i) + '_Label' + '.jpg')     #保存图片
            print(i)
            i = i + 1
    except tf.errors.OutOfRangeError:
        print("Done Test -- epoch limit reached")
    finally:
        coord.request_stop()
        print("stop()")
    coord.join(threads)
    print("stop()")
    sess.close()

‘utf-8’ codec can’t decode byte 0xd5 in position 105: invalid continuation byte错误

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值