def load_image(img_path, size=(32, 32), channels=3):
    """Read a JPEG file and return an image tensor paired with a binary label.

    Args:
        img_path: Path to a JPEG file (a Python string or a scalar string
            tensor, e.g. an element produced by ``tf.data.Dataset.list_files``).
        size: Target ``(height, width)`` passed to ``tf.image.resize``.
        channels: Number of colour channels to decode. Defaults to 3 so that
            every image is decoded as RGB and the channel dimension is
            statically known, which batching and downstream layers rely on.

    Returns:
        A tuple ``(img, label)``:
        ``img`` is the decoded image resized to ``size``, with pixel values
        divided by 255.0.
        ``label`` is an ``int8`` scalar: 1 if ``img_path`` fully matches
        ``.*/automobile/.*``, otherwise 0.
    """
    # Cast the boolean match directly to int8 instead of branching on a
    # tensor with a Python conditional, which only works in graph mode via
    # AutoGraph conversion.
    label = tf.cast(
        tf.strings.regex_full_match(img_path, ".*/automobile/.*"), tf.int8
    )
    img = tf.io.read_file(img_path)
    # An explicit channel count forces e.g. grayscale JPEGs to RGB and gives
    # the tensor a fixed last dimension; with the default (0) the channel
    # count depends on the file and is unknown at trace time.
    img = tf.image.decode_jpeg(img, channels=channels)
    img = tf.image.resize(img, size) / 255.0
    return img, label
# Training input pipeline:
#   1. list the JPEG files matching cifar2/train/*/*.jpg,
#   2. map load_image over them in parallel to get (image, label) pairs,
#   3. shuffle through a 1000-element buffer,
#   4. group into batches of BATCH_SIZE,
#   5. prefetch so input preparation overlaps with consumption.
ds_train = (
    tf.data.Dataset.list_files(
        "/content/drive/My Drive/eat_tf2_in_30_days_code/data/cifar2/train/*/*.jpg"
    )
    .map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
- 这里的 `buffer_size` 是 shuffle 采样池（缓冲区）的大小；`ds_train` 的每个元素是一个批次的 `(图片, 标签)` 元组。