1. 获取图片路径并生成标签列表
- 获取所有图片文件路径(
glob.glob
) - 根据图片类别(上一层文件夹名),生成标签列表(并编码成0,1,2…格式)
- 将一维标签列表转化成二维列表(
tf.reshape(test_image_label,[1])
)
def get_image_path():
train_image_path=glob.glob('./data/dc_2000/train/*/*.jpg')
test_image_path=glob.glob('./data/dc_2000/test/*/*.jpg')
train_image_label=[1 if elem.split('\\')[1]=='dog' else 0 for elem in train_image_path]
test_image_label=[1 if elem.split('\\')[1]=='dog' else 0 for elem in train_image_path]
train_image_label=tf.reshape(train_image_label,[1])
test_image_label=tf.reshape(test_image_label,[1])
2. 图像预处理
- 以二进制形式读取图片文件(
tf.io.read_file(path)
) - 将二进制文件解码成相应类型的图片(
tf.image.decode_jpeg(image,channels=3)
注意通道数量) - 将图片统一尺寸(
tf.image.resize(image,[256,256])
) - 改变图片每位像素的数据类型(
tf.image.cast(image,tf.float32)
) - 对每张图片的所有像素值进行归一化,除以像素范围最大值(
image=image/255
)
注:tf.image.convert_image_dtype
函数会将图片格式转化为float32
,并执行归一化,如果原数据类型是float32
,则不会进行数据归一化的操作
def load_process_file(filepath):
image = tf.io.read_file(filepath)
image = tf.image.decode_jpeg(image,channels=3)
image = tf.image.resize(image,[256,256])
image=tf.cast(image,tf.float32)/255
# image = tf.image.convert_image_dtype #次函数会将图片格式转化为float32,并执行归一化,如果原数据类型是float32,则不会进行数据归一化的操作
3. 生成dataset
- 使用
tf.data.Dataset.from_tensor_slices((图片路径列表,标签列表))
生成dataset
- 对生成的
dataset
执行图像预处理函数 (dataset.map(load_process_file,num_parallel_calls=tf.data.experimental.AUTOTUNE)
) - 对样本进行乱序(
dataset.shuffle(1000)
) - 对样本划分 batch(
datase.batch(32)
) - 使用
prefetch(tf.data.experimental.AUTOTUNE)
增加图片读取速度
train_ds=tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
train_ds.map(load_process_file,num_parallel_calls=tf.data.experimental.AUTOTUNE) #使用多线程,线程数自适应
test_ds = tf.data.Dataset.from_tensor_slices((test_image_path, test_image_label))
test_ds.map(load_process_file, num_parallel_calls=tf.data.experimental.AUTOTUNE) # 使用多线程,线程数自适应
BATCH_SIZE=32
train_count=len(train_image_path)
test_count=len(test_image_path)
train_ds=train_ds.repeat().shuffle(train_count).batch(BATCH_SIZE)
test_ds=test_ds.batch(BATCH_SIZE)
train_ds=train_ds.prefetch(tf.data.experimental.AUTOTUNE)
test_ds=test_ds.prefetch(tf.data.experimental.AUTOTUNE)
imgs,labels =next(iter(train_ds))