迁移学习和微调
预加载
import numpy as np
import tensorflow as tf
from tensorflow import keras
简介
迁移学习包括采用在一个问题上学到的功能,并在新的类似问题上加以利用。例如,来自已学会识别浣熊的模型的特征可能对启动旨在识别tanukis的模型很有用。
迁移学习通常用于数据集数据太少模型的训练任务。
在深度学习中,迁移学习最常见的体现是以下工作流程:
- 从预训练的模型中提取层。
- 冻结一些层,以免在以后的训练中破坏它们包含的任何信息。
- 在冻结层的顶部添加一些新的可训练层。他们将学习将旧功能转变为对新数据集的预测。
- 在数据集上训练新图层。
最后一个可选步骤是微调,包括解冻已训练好的整个模型(或部分模型),并以非常低的学习率基于新数据进行重新训练。通过将预训练的功能逐步适应新数据,可以潜在地实现显著的提升。
首先,我们将详细介绍Keras可训练的API,它是大多数迁移学习和微调工作流的基础。
然后,我们将通过在ImageNet数据集上进行预训练的模型,然后在Kaggle“猫与狗”分类数据集上对其进行重新训练,来演示典型的工作流程。
改编自Deep Learning with Python和2016年博客文章使用很少的数据构建强大的图像分类模型-building powerful image classification models using very little data".
冻结层:理解可训练的属性
Layers & models具有三个权重属性:
- weights权重是该层的所有权重变量的列表。
- trainable_weights是要进行更新(通过梯度下降)以最大程度地减少训练过程中的损失的列表。
- non_trainable_weights是不适合训练的列表。 通常,它们在模型正向传递中进行更新。
示例:Dense层具有2个可训练的权重(kernel和bias)
weights: 2
trainable_weights: 2
non_trainable_weights: 0
Layers & models还具有布尔属性trainable。 其值可以更改。 将layer.trainable设置为False会将图层的所有权重从可训练变为不可训练。 这称为“冻结”层:冻结层的状态在训练期间不会更新(无论是使用fit()进行训练,还是使用依赖trainable_weights来应用梯度更新的任何自定义循环进行训练)。
示例: 把 trainable 属性设为 False。
weights: 2
trainable_weights: 0
non_trainable_weights: 2
当一个权重的可训练属性变为不可训练时,其值将不会随训练过程更新。
1/1 [==============================] - 0s 1ms/step - loss: 0.1275
不要将layer.trainable属性与layer .__ call __()中的参数training混淆(该参数控制该层的前向过程是应该以推理模式还是训练模式运行)。 有关更多信息,请参见Keras FAQ。
trainable 属性的递归设置
如果你在model类或者具有子图层的layer类中设定trainable = False,所有的子图层都会变为不可训练。
示例:
迁移学习典型的工作流程
下面介绍一下Keras中迁移学习的典型过程:
- 实例化基本模型并将预训练的权重加载到其中。
- 通过设置trainable = False冻结基本模型中的所有层。
- 在基础模型的一层(或几层)的输出之上创建一个新模型。
- 在新的数据集上训练该新模型。
此外,还有另一种更轻量的工作流程:
- 实例化基本模型并将预训练的权重加载到其中。
- 通过它运行新的数据集,并记录基础模型中一层(或几层)的输出。这称为特征提取。
- 使用该输出作为新的较小模型的输入数据。
第二个工作流程的一个关键优势是,您仅对数据运行一次基本模型,而不是每次训练一次。因此,它更快,更节约。
但是,第二个工作流程的问题是,它不允许您在训练期间动态修改新模型的输入数据。而进行数据扩充时,恰恰需要修改输入数据。当新数据集的数据太少而无法从头训练完整模型时,常常需要用到迁移学习,而在这种情况下,数据扩充非常重要。因此,在接下来的内容中,我们将集中讲解第一个工作流程。
这是Keras中第一个工作流程的样子:
首先,实例化具有预训练权重的基本模型。
然后,冻结基本模型
在顶部创建新模型
基于新数据训练模型
Fine-tuning:微调
一旦您的模型收敛于新数据,您就可以尝试解冻全部或部分基本模型,并以非常低的学习率端到端重新训练整个模型。
微调是最后一个可选步骤,可以为您带来逐步的改进。它还可能会导致快速过拟合-请记住这一点。
至关重要的是只有在训练具有冻结层的模型以使其收敛之后才执行此步骤。如果将随机初始化的可训练图层与包含预训练要素的可训练图层混合使用,则随机初始化的图层将在训练过程中引起很大的渐变更新,这将破坏您的预训练要素。
在此阶段使用非常低的学习率也很关键,因为在通常很小的数据集上,您训练的模型比第一轮训练中的模型大得多。结果,如果应用较大的重量更新,则可能会很快过度拟合。在这里,您只想以增量方式重新适应预训练的权重。
这是实现整个基本模型的微调的方法:
关于 compile() 和 trainable的重要说明
在模型上调用 **compile()**旨在“冻结”该模型的行为。这意味着在编译模型时,应该在该模型的整个生命周期中保留可训练的属性值,直到再次调用 compile()为止。因此,如果您更改了trainable的值,请考虑再次在模型上调用 compile() 。
有关BatchNormalization层的重要说明
许多图像模型包含BatchNormalization图层。在每一个可以想象的数量上,该层都是一个特例。这里有几件事要牢记。
- BatchNormalization包含2个不可训练的权重,它们在训练过程中会更新。这些是跟踪输入的均值和方差的变量。
- 当设置bn_layer.trainable = False时,BatchNormalization层将以推断模式运行,并且不会更新其均值和方差统计信息。一般而言,对于其他layers不会出现这种情况,如 weight trainability & inference/training modes are two orthogonal concepts所述。但是在BatchNormalization层的情况下,两者是并列的。
- 当您解冻包含BatchNormalization图层的模型以进行微调时,应在调用基本模型时通过传递training = False来使BatchNormalization图层保持推理模式。否则,应用于不可训练权重的更新将突然破坏模型学到的知识。
您将在本手册末尾的端到端示例中看到这种模式。
通过自定义训练循环进行学习和微调(Transfer learning & fine-tuning with a custom training loop)
如果您使用的是自己的低级训练循环,而不是fit(),则工作流程基本保持不变。 在使用梯度更新时,您应注意仅考虑model.trainable_weights属性列表:
这同样适用于微调。
一个端到端的示例:微调猫—狗的图像分类模型
为了巩固这些概念,让我们为您介绍一个具体的端到端传输学习和微调示例。 我们将加载在ImageNet上经过预先训练的Xception模型,并将其用于Kaggle“猫与狗”分类数据集中。
加载数据
首先,让我们使用TFDS来获取cats vs. dogs数据集。 如果您拥有自己的数据集,则可能要使用接口tf.keras.preprocessing.image_dataset_from_directory从磁盘上已归档到特定于类的文件夹中的一组图像中生成相似的标记数据集对象。
在处理非常小的数据集时,迁移学习最有用。 为了使数据集保持较小,我们将使用原始训练数据的40%(25,000张图像)进行训练,将10%用于验证,将10%用于测试。
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...
Warning:absl:1738 images were corrupted and were skipped
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteIL7NQA/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326
下面是训练数据集中的前9张图像-都是不同的类型。
标签1是狗,标签0是猫。
数据标准化
我们的原始图像有各种尺寸。另外,每个像素由0到255之间的3个整数值(RGB级别值)组成。这不太适合提供神经网络。我们需要做两件事:
- 标准化为固定的图像尺寸。我们选择150x150。
- 归一化介于-1和1之间的像素值。我们将使用归一化层作为模型本身的一部分来进行此操作。
通常,与采用已预处理数据的模型相反,开发以原始数据为输入的模型是一个好习惯。原因是,如果模型需要预处理的数据,则每次导出模型以在其他地方使用它(在Web浏览器中,在移动应用中)时,都需要重新实现完全相同的预处理管道。这很快就变得非常棘手。因此,在达到模型之前,我们应该进行尽可能少的预处理。
在这里,我们将在数据管道中进行图像大小调整(因为深度神经网络只能处理连续的一批数据),并且在创建模型时将其作为模型的一部分进行输入值缩放。
让我们将图像调整为150x150:
此外,让我们分批处理数据并使用缓存和预取来优化加载速度
使用随机数据扩充
当没有大型图像数据集时,通过对训练图像应用随机但逼真的变换(例如随机水平翻转或较小的随机旋转)来人为引入样本多样性是一种很好的做法。 这有助于使模型暴露于训练数据的不同方面,同时减慢过度拟合的速度。
让我们直观地看到经过各种随机转换后的第一批图像:
建立模型
现在让我们建立一个遵循我们先前解释的蓝图的模型。
注意:
- 我们添加了Normalization层以将输入值(最初在[0,255]范围内)缩放到[-1,1]范围。
- 我们在分类层之前添加一个Dropout层,以进行正则化。
- 我们确保在调用基本模型时传递training = False,以使其在推理模式下运行,从而即使取消冻结基本模型进行微调后,batchnorm统计信息也不会得到更新。
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
Model: "functional_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 150, 150, 3)] 0
_________________________________________________________________
sequential_3 (Sequential) (None, 150, 150, 3) 0
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3) 7
_________________________________________________________________
xception (Functional) (None, 5, 5, 2048) 20861480
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048) 0
_________________________________________________________________
dropout (Dropout) (None, 2048) 0
_________________________________________________________________
dense_7 (Dense) (None, 1) 2049
=================================================================
Total params: 20,863,536
Trainable params: 2,049
训练顶层参数
Epoch 1/20
291/291 [==============================] - 9s 32ms/step - loss: 0.1758 - binary_accuracy: 0.9226 - val_loss: 0.0897 - val_binary_accuracy: 0.9660
Epoch 2/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1211 - binary_accuracy: 0.9497 - val_loss: 0.0870 - val_binary_accuracy: 0.9686
Epoch 3/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1166 - binary_accuracy: 0.9503 - val_loss: 0.0814 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1125 - binary_accuracy: 0.9534 - val_loss: 0.0825 - val_binary_accuracy: 0.9695
Epoch 5/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1073 - binary_accuracy: 0.9569 - val_loss: 0.0763 - val_binary_accuracy: 0.9703
Epoch 6/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1041 - binary_accuracy: 0.9573 - val_loss: 0.0812 - val_binary_accuracy: 0.9686
Epoch 7/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1023 - binary_accuracy: 0.9567 - val_loss: 0.0820 - val_binary_accuracy: 0.9669
Epoch 8/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1005 - binary_accuracy: 0.9597 - val_loss: 0.0779 - val_binary_accuracy: 0.9695
Epoch 9/20
291/291 [==============================] - 8s 28ms/step - loss: 0.1019 - binary_accuracy: 0.9580 - val_loss: 0.0813 - val_binary_accuracy: 0.9699
Epoch 10/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0940 - binary_accuracy: 0.9651 - val_loss: 0.0762 - val_binary_accuracy: 0.9729
Epoch 11/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0974 - binary_accuracy: 0.9613 - val_loss: 0.0752 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0965 - binary_accuracy: 0.9591 - val_loss: 0.0760 - val_binary_accuracy: 0.9721
Epoch 13/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0962 - binary_accuracy: 0.9598 - val_loss: 0.0785 - val_binary_accuracy: 0.9712
Epoch 14/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0966 - binary_accuracy: 0.9616 - val_loss: 0.0831 - val_binary_accuracy: 0.9699
Epoch 15/20
291/291 [==============================] - 8s 27ms/step - loss: 0.1000 - binary_accuracy: 0.9574 - val_loss: 0.0741 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 8s 28ms/step - loss: 0.0940 - binary_accuracy: 0.9628 - val_loss: 0.0781 - val_binary_accuracy: 0.9686
Epoch 17/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0915 - binary_accuracy: 0.9634 - val_loss: 0.0843 - val_binary_accuracy: 0.9678
Epoch 18/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0937 - binary_accuracy: 0.9620 - val_loss: 0.0829 - val_binary_accuracy: 0.9669
Epoch 19/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0988 - binary_accuracy: 0.9601 - val_loss: 0.0862 - val_binary_accuracy: 0.9686
Epoch 20/20
291/291 [==============================] - 8s 27ms/step - loss: 0.0928 - binary_accuracy: 0.9644 - val_loss: 0.0798 - val_binary_accuracy: 0.9703
<tensorflow.python.keras.callbacks.History at 0x7f6104f04518>
对整个模型进行一轮微调
最后,让我们解冻基本模型并以较低的学习率端到端地训练整个模型。
重要的是,尽管基本模型变得可训练,但由于我们在构建模型时调用该模型时传递了training = False,因此它仍在推理模式下运行。 这意味着内部的批处理规范化层不会更新其批处理统计信息。 如果这样做的话,他们将破坏迄今为止该模型所学习的表示形式。
Model: "functional_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 150, 150, 3)] 0
_________________________________________________________________
sequential_3 (Sequential) (None, 150, 150, 3) 0
_________________________________________________________________
normalization (Normalization (None, 150, 150, 3) 7
_________________________________________________________________
xception (Functional) (None, 5, 5, 2048) 20861480
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048) 0
_________________________________________________________________
dropout (Dropout) (None, 2048) 0
_________________________________________________________________
dense_7 (Dense) (None, 1) 2049
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/10
2/291 [..............................] - ETA: 17s - loss: 0.1439 - binary_accuracy: 0.9219WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.
Warning:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0329s vs `on_train_batch_end` time: 0.0905s). Check your callbacks.
291/291 [==============================] - 38s 132ms/step - loss: 0.0786 - binary_accuracy: 0.9706 - val_loss: 0.0631 - val_binary_accuracy: 0.9772
Epoch 2/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0553 - binary_accuracy: 0.9790 - val_loss: 0.0537 - val_binary_accuracy: 0.9781
Epoch 3/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0442 - binary_accuracy: 0.9829 - val_loss: 0.0532 - val_binary_accuracy: 0.9819
Epoch 4/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0369 - binary_accuracy: 0.9858 - val_loss: 0.0460 - val_binary_accuracy: 0.9832
Epoch 5/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0335 - binary_accuracy: 0.9870 - val_loss: 0.0561 - val_binary_accuracy: 0.9794
Epoch 6/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0253 - binary_accuracy: 0.9910 - val_loss: 0.0559 - val_binary_accuracy: 0.9819
Epoch 7/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0232 - binary_accuracy: 0.9920 - val_loss: 0.0432 - val_binary_accuracy: 0.9845
Epoch 8/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0185 - binary_accuracy: 0.9930 - val_loss: 0.0396 - val_binary_accuracy: 0.9854
Epoch 9/10
291/291 [==============================] - 38s 129ms/step - loss: 0.0147 - binary_accuracy: 0.9948 - val_loss: 0.0439 - val_binary_accuracy: 0.9832
Epoch 10/10
291/291 [==============================] - 37s 129ms/step - loss: 0.0117 - binary_accuracy: 0.9954 - val_loss: 0.0538 - val_binary_accuracy: 0.9819
<tensorflow.python.keras.callbacks.History at 0x7f611c26e438>
经过10个epoch后,微调后的模型得到了显著的提升。
翻译自tesnorflow官方社区手册,主要内容为机翻,有不足之处,欢迎各位朋友在评论中指出。