本实例使用 MNIST 数据集。我的数据集已经解压为 BMP 格式的图片，按类别存放：每个类别一个子目录，目录名就是对应的数字标签。需要该数据集的读者请在评论区留言，我看到后会第一时间回复，谢谢。
接下来看看具体的实现过程。话不多说，直接上代码：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ProjectName : 02_load_mnist.py
# @DateTime : 2019-11-23 10:21
# @Author : 皮皮虾
import os
import argparse
import logging
import numpy as np
import tensorflow as tf
from sklearn.utils import shuffle
from matplotlib import pyplot as plt
def load_mnist_data(src_path, random_state=None):
    """Collect image paths and integer labels from a class-per-directory tree.

    The expected layout is ``src_path/<label>/<image file>``: every
    sub-directory of ``src_path`` is one class and its name is the integer
    label of the files inside it. Entries of ``src_path`` that are not
    directories, and entries of a class directory that are not regular
    files, are ignored.

    Args:
        src_path: root directory containing one sub-directory per class.
        random_state: forwarded to ``sklearn.utils.shuffle``; pass an int to
            get a reproducible permutation (default ``None`` = random).

    Returns:
        A pair ``(shuffled, class_names)``:
        ``shuffled`` is what ``sklearn.utils.shuffle`` returns for the two
        inputs, i.e. ``[image_paths, labels]`` where ``image_paths`` is a
        numpy array of path strings and ``labels`` a list of ints, permuted
        consistently; ``class_names`` is a numpy array of the class directory
        names, sorted as strings.

    Raises:
        ValueError: if a class directory name is not an integer.
    """
    images_path_list = []
    labels_list = []
    # Only directories are classes. A stray file at the root (``.DS_Store``,
    # a README, ...) would otherwise make os.listdir() below fail and would
    # also show up as a bogus class name.
    class_dirs = sorted(
        entry for entry in os.listdir(src_path)
        if os.path.isdir(os.path.join(src_path, entry))
    )
    for class_name in class_dirs:
        class_dir = os.path.join(src_path, class_name)
        label = int(class_name)
        # Sorted so that a fixed random_state yields the same permutation on
        # every machine (os.listdir order is arbitrary).
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            if not os.path.isfile(image_path):
                continue
            images_path_list.append(image_path)
            labels_list.append(label)
    shuffled = shuffle(np.asarray(images_path_list), labels_list, random_state=random_state)
    return shuffled, np.asarray(class_dirs)
def get_batches(image_path, label, batch_size, resize_height=28, resize_width=28, channels=1,
                num_threads=64, capacity=32):
    """Build a TF1 queue-based input pipeline yielding (image, label) batches.

    Each file is read, decoded as BMP, centrally cropped or zero-padded to
    ``resize_height`` x ``resize_width`` and standardized per image. The
    returned tensors only produce data once the queue runners have been
    started (``tf.train.start_queue_runners``) inside a session.

    Args:
        image_path: sequence of image file paths (strings), aligned with ``label``.
        label: sequence of integer labels, one per entry of ``image_path``.
        batch_size: number of examples per batch.
        resize_height: target height after crop/pad.
        resize_width: target width after crop/pad.
        channels: number of colour channels requested from ``tf.image.decode_bmp``.
        num_threads: number of enqueuing threads used by ``tf.train.batch``.
        capacity: maximum number of examples buffered by ``tf.train.batch``
            (32 is that function's own default).

    Returns:
        ``(images_batch, labels_batch)``: a float32 image tensor, expected
        shape ``[batch_size, resize_height, resize_width, channels]``, and a
        label tensor reshaped to ``[batch_size]``.
    """
    # One (path, label) pair per dequeue; with TensorFlow's defaults this
    # producer reshuffles the data and cycles over it indefinitely.
    path_tensor, label_tensor = tf.train.slice_input_producer(tensor_list=[image_path, label])
    raw_bytes = tf.read_file(filename=path_tensor)
    image = tf.image.decode_bmp(contents=raw_bytes, channels=channels)
    # Crop/pad (no interpolation) to a fixed size so batching gets a static shape.
    image = tf.image.resize_image_with_crop_or_pad(image=image,
                                                   target_height=resize_height,
                                                   target_width=resize_width)
    # Standardization centres each image by subtracting its mean (and scales
    # by its adjusted standard deviation). Centred inputs generally make
    # optimization easier and help the trained model generalize; this is one
    # of the most common preprocessing steps.
    image = tf.image.per_image_standardization(image=image)
    image_batch, label_batch = tf.train.batch(tensors=[image, label_tensor],
                                              batch_size=batch_size,
                                              num_threads=num_threads,
                                              capacity=capacity)
    # Make the float32 dtype explicit for downstream consumers.
    images_batch = tf.cast(x=image_batch, dtype=tf.float32)
    labels_batch = tf.reshape(tensor=label_batch, shape=[batch_size])
    return images_batch, labels_batch
def show_single_image(subplot, label, image, shape=(28, 28)):
    """Draw one image, titled with its label, into a subplot of the current figure.

    Args:
        subplot: a single positional argument accepted by ``plt.subplot``,
            e.g. a three-digit code such as ``191``.
        label: value shown as the subplot title.
        image: array-like holding exactly ``shape[0] * shape[1]`` values.
        shape: 2-D shape the image is reshaped to for display. The default
            is MNIST's 28x28; it should match the ``resize_height`` and
            ``resize_width`` used when the batch was built.
    """
    plt.subplot(subplot)
    plt.axis("off")
    # Shape passed positionally: the ``newshape`` keyword is deprecated in
    # recent NumPy releases.
    plt.imshow(np.reshape(image, shape))
    plt.title(label=label)
def show_batch_image(label, image, top):
    """Show the first ``top`` images of a batch side by side in one row.

    At most nine images are drawn: each position is encoded as a three-digit
    ``plt.subplot`` code (1 row, ``count`` columns, index), and each digit
    must stay below 10. The count is also capped by the batch length.
    Nothing is shown when there is nothing to draw.

    Args:
        label: sequence of labels, used as subplot titles.
        image: sequence of images aligned with ``label``.
        top: requested number of images to display.
    """
    count = min(top, 9, len(label), len(image))
    if count <= 0:
        return
    plt.figure(figsize=(20, 10))
    for i in range(count):
        show_single_image(subplot=100 + 10 * count + 1 + i, label=label[i], image=image[i])
    plt.show()
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(filename)s - %(lineno)s - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
    parser = argparse.ArgumentParser()
    # Required, so no default: argparse never uses the default of a required option.
    parser.add_argument(
        "--input_path",
        type=str,
        required=True,
        help="the mnist data input path, like as 'mnist_digits_images'"
    )
    # Optional with a real default (it was required before, so the default was dead).
    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="batch size"
    )
    parser.add_argument(
        "--steps",
        default=10,
        type=int,
        help="number of batches to fetch and display"
    )
    # parse_known_args ignores unrecognized arguments (e.g. ones injected by
    # a notebook kernel) instead of exiting.
    FLAGS, _ = parser.parse_known_args()
    if FLAGS.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    logger.info({"FLAGS": FLAGS})
    (images_path, labels), _ = load_mnist_data(src_path=FLAGS.input_path)
    image_batchs, label_batchs = get_batches(image_path=images_path,
                                             label=labels,
                                             batch_size=FLAGS.batch_size)
    # start session
    with tf.Session() as sess:
        # initialize global variables
        sess.run(tf.global_variables_initializer())
        # Coordinator used to stop the queue-runner threads together.
        coord = tf.train.Coordinator()
        # Start the threads that fill the input queues built by get_batches.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(FLAGS.steps):
                if coord.should_stop():
                    break
                batch_images, batch_labels = sess.run([image_batchs, label_batchs])
                show_batch_image(label=batch_labels, image=batch_images, top=FLAGS.batch_size)
                print("step{}".format(step))
                print("labels:", batch_labels)
        except tf.errors.OutOfRangeError:
            print("finish!")
        finally:
            # Always stop and join the queue threads, even after an error.
            coord.request_stop()
            coord.join(threads=threads)
代码实现部分就介绍到这里。