TensorFlow学习笔记——训练神经网络模型的第一步:构建自己的图片数据集TFRecords

0. 导读

相信很多深度学习的初学者都看着大神教程搭过很多的CNN、RNN、VGG等网络模型,训练模型时用的基本都是mnist、fashion-mnist、cifar10等基准数据集,当面对自己实验室的或者自己网上爬下来的数据集时就会犯难,根本不知道怎么调整算法模型去读取我们的本地数据集。问题不大,这个对于每一个深度学习工作者都是漫漫探索路上的必经之路。以下我将分享自己的学习心得以及完整代码实现,为曾和我一样正犯难的小伙伴提供一些参考。

代码部分已验证无误,并做了完整注释,请放心食用!

关于读取数据,TensorFlow提供了3种读取方法:

  1. Feeding: placeholder,feed_dict由占位符代替数据,运行时在线填入数据;
  2. Reader: 从文件中直接读取,在一个计算图(tf.graph)的开始前,将文件读入队列(queue)中;
  3. Preloaded data: 预加载数据;

鉴于TensorFlow提供了标准的TFRecord格式,接下来我将介绍就是上述的第2种方法。利用tf.record标准口来读入文件。本程序主要包含以下3个核心部分:

制作TFrecord
读取TFrecord数据获得image和label
打印验证并保存生成的图片

1. 准备数据

先在网上下载不同类的图片集,例如几个品种的狗的图片。实验室数据集不方便公开,暂时简单使用网上下载的Dog的图片做介绍。此处已预先下载哈奇士、吉娃娃两种狗的照片各30张,如下:

在这里插入图片描述

2. 代码拆分解释

2.1 制作TFrecord

#将原始图片转换成需要的大小,并将其保存
#=========================================================================================
import os
import tensorflow as tf
from PIL import Image

#原始图片的储存位置
orig_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/train'

#生成图片的储存位置
gen_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/Re_train/image_data/inputdata'

#需要的识别类型
classes = {'husky', 'jiwawa'}

#样本总数
num_samples = 60

#制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter(gen_picture + "/dogs_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = orig_picture + "/" + name + "/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64,64))     #设置需要转换的图片大小
            img_raw = img.tobytes()       #将图片转化为原生bytes
            print(index,img_raw)
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
            writer.write(example.SerializeToString())
    writer.close()

以下代码段为单独测试部分,结合上述代码块即可制作狗的训练集TFRecord数据,命名为dogs_train.tfrecords,如生成文件如下图所示:

if __name__ == '__main__':
     create_record()
     print("Finished")

在这里插入图片描述

2.2 读取TFRecord数据获得image 和 label

def read_and_decode(filename):
    #创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    #create a reader from file queue
    reader = tf.TFRecordReader()
    #reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    #get feature from serialized example
    #解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)})
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64,64,3])
    #img = tf.cast(img, tf.float32) * (1. /255) -0.5
    label = tf.cast(label, tf.int32)
    return img, label

2.3 打印验证并保存生成的图片

if __name__ == '__main__':
    # create_record()
    batch = read_and_decode('dogs_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:    #开启一个会话
        sess.run(init_op)
        coord = tf.train.Coordinator()   #多线程管理器
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(num_samples):
            example, lab = sess.run(batch)              #在会话中取出image 和label
            img = Image.fromarray(example, 'RGB')       #这里Image是之前提到的   Image.fromarray()实现array 到image的转换
            img.save(gen_picture + '/' + str(i) + '_Label_' + str(lab) + '.jpg')  #保存图片;注意cwd后面加上‘/’
        coord.request_stop()
        coord.join(threads)
        sess.close()

3. 完整代码实现

"""
**利用TensorFlow训练自己的图片数据(1)——预处理
@author: <Colynn Johnson>
@direct: https://blog.csdn.net/ywx1832990/article/details/78609323
@date: 2020-8-20
"""

"""
首先,我们需要准备训练的原始数据,本次训练为图像分类识别,因而一开始,笔者从网上随机的下载了Dog的四种类别:
husky,jiwawa。每种类别30种,一共60张图片。在训练之前,需要做的就是进行图像的预处理,
即将这些大小不一的原始图片转换成我们训练需要的shape。

编程实现包括:制作TFrecord,读取TFrecord数据获得image和label,打印验证并保存生成的图片
"""

#将原始图片转换成需要的大小,并将其保存
#=========================================================================================
import os
import tensorflow as tf
from PIL import Image

#原始图片的储存位置
orig_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/train'

#生成图片的储存位置
gen_picture = 'C:/Users/94092/Desktop/Src/tensorflow_test/data/50class/Re_train/image_data/inputdata'

#需要的识别类型
classes = {'husky', 'jiwawa'}

#样本总数
"""待定!!"""
num_samples = 60

#制作TFRecords数据
def create_record():
    writer = tf.python_io.TFRecordWriter(gen_picture + "/dogs_train.tfrecords")
    for index, name in enumerate(classes):
        class_path = orig_picture + "/" + name + "/"
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name
            img = Image.open(img_path)
            img = img.resize((64,64))     #设置需要转换的图片大小
            img_raw = img.tobytes()       #将图片转化为原生bytes
            print(index,img_raw)
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                    "img_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
                }))
            writer.write(example.SerializeToString())
    writer.close()

# if __name__ == '__main__':
#     create_record()
#     print("Finished")

#=================================================================================================================
"""
读取TFRecord数据获得image 和 label
"""
def read_and_decode(filename):
    #创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    #create a reader from file queue
    reader = tf.TFRecordReader()
    #reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    #get feature from serialized example
    #解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string)})
    label = features['label']
    img = features['img_raw']
    img = tf.decode_raw(img, tf.uint8)
    img = tf.reshape(img, [64,64,3])
    #img = tf.cast(img, tf.float32) * (1. /255) -0.5
    label = tf.cast(label, tf.int32)
    return img, label

#========================================================================================
"""
打印验证并保存生成的图片
"""
if __name__ == '__main__':
    create_record()
    batch = read_and_decode('dogs_train.tfrecords')
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    with tf.Session() as sess:    #开启一个会话
        sess.run(init_op)
        coord = tf.train.Coordinator()   #多线程管理器
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(num_samples):
            example, lab = sess.run(batch)              #在会话中取出image 和label
            img = Image.fromarray(example, 'RGB')       #这里Image是之前提到的   Image.fromarray()实现array 到image的转换
            img.save(gen_picture + '/' + str(i) + '_Label_' + str(lab) + '.jpg')  #保存图片;注意cwd后面加上‘/’
        coord.request_stop()
        coord.join(threads)
        sess.close()

结果展示

在这里插入图片描述
每一幅图片的命名中,第二个数字则是 label,吉娃娃都为1,哈士奇都为0;通过对照图片,可以发现图片分类正确。

Reference: https://blog.csdn.net/ywx1832990/article/details/78609323

  • 0
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值