import tensorflow as tf
import glob as glob
def load_process_image(path,label):
image = tf.io.read_file(path)
image = tf.image.decode_jpeg(image,channels=3)
image = tf.image.resize(image,[256,256])
image = image/255
return image,label
if __name__ == '__main__':
train_image_path = glob.glob('.\dc_2000\\train\*\*.jpg')#['.', 'dc_2000', 'train', 'cat', 'cat.0.jpg']
train_image_label = [int(p.split('\\')[3] == 'cat') for p in train_image_path]#1为cat
train_image_ds = tf.data.Dataset.from_tensor_slices((train_image_path,train_image_label))
AUTOTUNE = tf.data.experimental.AUTOTUNE#map并行处理
train_image_ds = train_image_ds.map(load_process_image, num_parallel_calls=AUTOTUNE)
BATCH_SIZE = 32
train_count = len(train_image_path)
train_image_ds = train_image_ds.shuffle(train_count).repeat().batch(BATCH_SIZE)
test_image_path = glob.glob('.\dc_2000\\train\*\*.jpg')
test_image_label = [int(p.split('\\')[3] == 'cat') for p in train_image_path]
test_image_ds = tf.data.Dataset.from_tensor_slices((test_image_path, test_image_label))
test_image_ds = test_image_ds.map(load_process_image, num_parallel_calls=AUTOTUNE)
test_image_ds = test_image_ds.repeat().batch(BATCH_SIZE)
test_count = len(test_image_path)
#keras内置经典网络
covn_base = tf.keras.applications.xception.Xception(
weights='imagenet',#使用在imagenet训练好权重
include_top=False,#只引入卷积基,不引入输出层
input_shape=(256,256,3),
pooling='avg'
)
model = tf.keras.Sequential()
model.add(covn_base)
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
#covn_base.trainable = False,xception需要可训练
model.compile(
optimizer='adam',
loss=tf.keras.losses.binary_crossentropy,
metrics=['acc']
)
history = model.fit(
train_image_ds,
steps_per_epoch=train_count//BATCH_SIZE,
epochs=5,
validation_data=test_image_ds,
validation_steps=test_count//BATCH_SIZE
)
print(history)
微调
只有分类器已经训练好了,才能微调卷积基的顶部卷积层。如果有没有这样的话,刚开始的训练误差很大,微调之前这些卷积层学到的表示会被破坏掉
- 在预训练卷积基上添加自定义层
- 冻结卷积基所有层
- 训练添加的分类层
- 解冻卷积基的一部分层
- 联合训练解冻的卷积层和添加的自定义层
covn_base.trainable = True
for layer in covn_base.layers[:-33]:
layer.trainable=False#对后33个层训练