TensorFlow Dataset Creation Series (Part 2): Building an In-Memory Dataset from Images

This post shows how to turn the MNIST dataset into an in-memory dataset for use with TensorFlow. It starts from MNIST images that have already been unpacked to disk and walks through the code that loads them into memory as batches.

This example uses the MNIST dataset. My copy has already been unpacked into individual image files; if you need the MNIST data, leave a comment and I will reply as soon as I see it. Thanks.
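For reference, the loading code below assumes the unpacked data is laid out as one sub-folder per digit, with the folder name doubling as the label (the file names here are only placeholders):

mnist_digits_images/
    0/
        0_0001.bmp
        ...
    1/
        ...
    9/
        ...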

Now let's look at the concrete steps. Without further ado, here is the code. Note that it uses the TensorFlow 1.x queue-runner API (tf.train.slice_input_producer / tf.train.batch), so it needs a 1.x release to run as written:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 02_load_mnist.py
# @DateTime :  2019-11-23 10:21
# @Author : 皮皮虾

import os
import argparse
import logging
import numpy as np
import tensorflow as tf
from sklearn.utils import shuffle
from matplotlib import pyplot as plt


def load_mnist_data(src_path):
    images_path_list = []
    labels_list = []
    sub_dir = os.listdir(path=src_path)
    for _dir_ in sub_dir:
        image_dirname_path = os.path.join(src_path, _dir_)
        images_list = os.listdir(image_dirname_path)
        for image in images_list:
            image_path = os.path.join(image_dirname_path, image)
            # get each image path
            images_path_list.append(image_path)
            # get each image label
            labels_list.append(_dir_)
    # class names, sorted
    label = sorted(sub_dir)

    # shuffle image paths and labels together so they stay aligned
    shuffled_paths, shuffled_labels = shuffle(np.asarray(images_path_list),
                                              list(map(int, labels_list)))
    return (shuffled_paths, shuffled_labels), np.asarray(label)


def get_batches(image_path, label, batch_size, resize_height=28, resize_width=28, channels=1):
    # create input queue
    queue = tf.train.slice_input_producer(tensor_list=[image_path, label])
    # get label from input queue
    label = queue[1]
    # read the raw image bytes from the path string tensor
    image_contents = tf.read_file(filename=queue[0])
    # decode image
    image = tf.image.decode_bmp(contents=image_contents, channels=channels)
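    # (if your unpacked images happen to be PNG rather than BMP, tf.image.decode_png
    # takes the same contents/channels arguments and can be swapped in here)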
    # resize image
    image = tf.image.resize_image_with_crop_or_pad(image=image,
                                                   target_height=resize_height,
                                                   target_width=resize_width)
    """
    图像的标准化是将数据通过去均值实现中心化的处理,根据凸优化理论和数据概率分布相关的知识,数据中心化
    符合数据分布规律,更容易取得训练之后的泛化效果,数据标准化是数据预处理常见的方法之一
    """
    # process image to standard
    image = tf.image.per_image_standardization(image=image)
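    # per the TF docs, per_image_standardization computes (pixel - mean) / adjusted_stddev,
    # with adjusted_stddev = max(stddev, 1 / sqrt(number_of_pixels)), so near-constant
    # images do not blow up the division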
    # get batch_size data
    image_batch, label_batch = tf.train.batch(tensors=[image, label],
                                              batch_size=batch_size,
                                              num_threads=64)
    # convert image data type to float32
    images_batch = tf.cast(x=image_batch, dtype=tf.float32)
    # reshape label
    labels_batch = tf.reshape(tensor=label_batch, shape=[batch_size])

    return images_batch, labels_batch


def show_single_image(subplot, label, image):
    plt.subplot(subplot)
    plt.axis("off")
    plt.imshow(np.reshape(a=image, newshape=[28, 28]))
    plt.title(label=label)


def show_batch_image(label, image, top):
    plt.figure(figsize=(20, 10))
    plt.axis("off")
    top = min(top, 9)
    for i in range(top):
        show_single_image(subplot=100 + 10 * top + 1 + i, label=label[i], image=image[i])
    plt.show()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(filename)s - %(lineno)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        default=r" ",
        type=str,
        required=True,
        help="the mnist data input path, like as 'mnist_digits_images'"
    )
    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        required=True,
        help="batch size"
    )
    FLAGS, _ = parser.parse_known_args()
    logger.info({"FLAGS": FLAGS})

    (images_path, labels), _ = load_mnist_data(src_path=FLAGS.input_path)

    image_batchs, label_batchs = get_batches(image_path=images_path,
                                             label=labels,
                                             batch_size=FLAGS.batch_size)

    # start session
    with tf.Session() as sess:
        # initialize global variables
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
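        # side note: if slice_input_producer were given a num_epochs argument, the epoch
        # counter would live in local variables, so tf.local_variables_initializer()
        # would also need to run here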
        # create the queue coordinator
        coord = tf.train.Coordinator()
        # start the queue-runner threads
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(10):
                if coord.should_stop():
                    break
                else:
                    images, labels = sess.run([image_batchs, label_batchs])
                    show_batch_image(label=labels, image=images, top=FLAGS.batch_size)
                    print("step{}".format(step))
                    print("labels:", labels)
        except tf.errors.OutOfRangeError:
            print("finish!")
        finally:
            coord.request_stop()
        coord.join(threads=threads)

That wraps up the walkthrough of the code.
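For reference, the script can be launched like this, using the example folder name from the argparse help (point --input_path at wherever your unpacked images actually live):

python 02_load_mnist.py --input_path mnist_digits_images --batch_size 16

Each of the 10 loop steps pulls one shuffled batch off the queue, displays up to 9 of its images with their labels as subplot titles, and prints the label batch to the console.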
