TensorFlow 2.0是对1.x版本做了一次大的瘦身,Eager Execution默认开启,并且使用Keras作为默认高级API,
这些改进大大降低的TensorFlow使用难度。
本文主要记录了一次曲折的使用Keras+TensorFlow2.0的BatchNormalization的踩坑经历,这个坑差点要把TF2.0的新特性都毁灭殆尽,如果你在学习TF2.0的官方教程,不妨一观。
问题的产生
从教程[1]https://www.tensorflow.org/alpha/tutorials/images/transfer_learning?hl=zh-cn(讲述如何Transfer Learning)说起:
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,weights='imagenet')
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(NUM_CLASSES)
])
简单的代码我们就复用了MobileNetV2的结构创建了一个分类器模型,接着我们就可以调用Keras的接口去训练模型:
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
loss='sparse_categorical_crossentropy',
metrics=['sparse_categorical_accuracy'])
model.summary()
history = model.fit(train_batches.repeat(),
epochs=20,
steps_per_epoch = steps_per_epoch,
validation_data=validation_batches.repeat(),
validation_steps=validation_steps)
输出的结果看,一起都很完美:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
mobilenetv2_1.00_160 (Model) (None, 5, 5, 1280) 2257984
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280) 0
_________________________________________________________________
dense (Dense) (None, 2) 1281
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________
Epoch 11/20
581/581 [==============================] - 134s 231ms/step - loss: 0.4208 - accuracy: 0.9484 - val_loss: 0.1907 - val_accuracy: 0.9812
Epoch 12/20
581/581 [==============================] - 114s 197ms/step - loss: 0.3359 - accuracy: 0.9570 - val_loss: 0.1835 - val_accuracy: 0.9844
Epoch 13/20
581/581 [==============================] - 116s 200ms/step - loss: 0.2930 - accuracy: 0.9650 - val_loss: 0.1505 - val_accuracy: 0.9844
Epoch 14/20
581/581 [==============================] - 114s 196ms/step - loss: 0.2561 - accuracy: 0.9701 - val_loss: 0.1575 - val_accuracy: 0.9859
Epoch 15/20
581/581 [==============================] - 119s 206ms/step - loss: 0.2302 - accuracy: 0.9715 - val_loss: 0.1600 - val_accuracy: 0.9812
Epoch 16/20
581/581 [==============================] - 115s 197ms/step - loss: 0.2134 - accuracy: 0.9747 - val_loss: 0.1407 - val_accuracy: