# 1. 数据集介绍
# JPG 图像格式的 MNIST 数据集,放在 ./database1/ 文件夹下:每个类别一个子文件夹,子文件夹名即为该类别的标签名。
# 2. 用 TensorFlow 读取 JPG 图像格式的 MNIST 数据集
# TensorFlow 1.x 的读取方式:
# TensorFlow 1.12 及以上版本的读取方式(最好使用 1.13.1 或 2.x),即下面的代码:
# 参考: https://blog.csdn.net/Black_Friend/article/details/104529859
import tensorflow as tf
import random
import pathlib
# Dataset root: one sub-folder per class; the sub-folder name is the class label.
data_path = pathlib.Path('./database1/')
print(type(data_path))  # pathlib.WindowsPath on Windows, pathlib.PosixPath elsewhere

# Every entry exactly two levels deep: database1/<label>/<entry>.
all_image_paths = list(data_path.glob('*/*'))
print(type(data_path.glob('*/*')))  # e.g. <class 'generator'>
# Keep only regular JPEG files. Stray entries (nested folders, Windows'
# Thumbs.db, ...) would otherwise reach tf.image.decode_jpeg and crash.
all_image_paths = [
    str(path)  # list of all image paths as plain strings
    for path in all_image_paths
    if path.is_file() and path.suffix.lower() in ('.jpg', '.jpeg')
]
random.shuffle(all_image_paths)  # shuffle once so the classes are interleaved
image_count = len(all_image_paths)
print('image_count: ', image_count)
if not all_image_paths:
    # Fail early with a clear message instead of an obscure dtype error
    # from tf.data when it is fed empty Python lists.
    raise FileNotFoundError(
        'No .jpg/.jpeg images found under {}/<label>/'.format(data_path.resolve()))

# Class names are the sub-folder names, sorted so the name -> index mapping is stable.
label_names = sorted(item.name for item in data_path.glob('*/') if item.is_dir())
print('label_names: ', label_names)
label_to_index = {name: index for index, name in enumerate(label_names)}
print('label_to_index: ', label_to_index)
# The label of an image is the name of the folder that contains it.
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]

# Dataset of (path string, integer label) pairs; decoding happens later in map().
db_train = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
def load_and_preprocess_from_path_label(path, label, num_classes=10):
    """Read one JPEG image from disk and one-hot encode its label.

    Meant to be used as the ``tf.data.Dataset.map`` function for a dataset
    of ``(path, label)`` pairs.

    Args:
        path: scalar string tensor holding the image file path.
        label: scalar integer tensor holding the class index.
        num_classes: length of the one-hot label vector (10 for MNIST).

    Returns:
        ``(image, label)``, where ``image`` is a float32 tensor of shape
        (height, width, 3) with values in [0, 1], and ``label`` is a float32
        one-hot vector of length ``num_classes``.
    """
    image = tf.io.read_file(path)  # raw encoded bytes of the file
    # channels=3 forces an RGB output even for grayscale JPEGs.
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, dtype=tf.float32) / 255.0  # uint8 [0, 255] -> float32 [0, 1]
    # NOTE(review): no resize is applied, so all images must share one size
    # (presumably 28x28 for MNIST) for batching to work; add
    # tf.image.resize(image, [28, 28]) if that is not guaranteed.
    label = tf.cast(label, dtype=tf.int32)
    label = tf.one_hot(label, depth=num_classes)
    return image, label
# tf.data transformations are not in-place: each call returns a NEW dataset,
# so the result must be re-assigned. Discarding it (as before) left db_train
# as the raw (path, label) dataset with scalar element shapes.
db_train = (
    db_train
    .shuffle(1000)                               # shuffle buffer of 1000 elements
    .map(load_and_preprocess_from_path_label)    # path -> image, index -> one-hot
    .batch(64)
    .repeat(2)                                   # iterate over the data twice
)
print(type(db_train))  # the concrete dataset class depends on the TF version
# Expected element shapes: images (None, None, None, 3), labels (None, 10).
try:
    print(db_train.element_spec)   # newer TF (2.x API)
except AttributeError:
    print(db_train.output_shapes)  # TF 1.x API, absent from TF 2.x datasets