06-Transfer learning: classifying cats and dogs with a MobileNet V2 model using ImageNet-trained weights


Transfer learning applies knowledge learned earlier to a new problem in order to reach a good solution.

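The example in this article is model-based transfer, i.e. reusing the parameters of an existing model. This kind of method is used very widely with neural networks, because a network's structure can be transferred directly. The familiar practice of fine-tuning is a good example of transferring model parameters.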

For more background on transfer learning, see:

  1. https://blog.csdn.net/epubit17/article/details/110390339
  2. https://blog.csdn.net/qq_42951560/article/details/110244616

1. Data preprocessing

1.1 Downloading the data

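Here we use a dataset containing several thousand images of cats and dogs. Download and extract the zip file with the images, then use the tf.keras.preprocessing.image_dataset_from_directory utility to create tf.data.Dataset objects for training and validation.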

import os
import matplotlib.pyplot as plt
import tensorflow as tf
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
print(os.path.dirname(path_to_zip))  # the Keras cache, e.g. ~/.keras/datasets
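The step that turns the extracted folder into train_dataset and validation_dataset is not shown above. A minimal sketch, assuming the archive extracts to a cats_and_dogs_filtered folder with train/ and validation/ subfolders next to the zip, and assuming BATCH_SIZE = 32 and IMG_SIZE = (160, 160), which is what the shapes printed later suggest. make_datasets and infer_class_names are hypothetical helper names; TensorFlow is imported inside make_datasets so the pure-Python label helper can be run on its own. The test_dataset used further down is not covered by this sketch.

```python
import os

BATCH_SIZE = 32          # assumption: matches the (32, 160, 160, 3) batch printed later
IMG_SIZE = (160, 160)    # assumption: matches the 160x160 MobileNetV2 input used below


def infer_class_names(directory):
    """Mimic how image_dataset_from_directory(labels='inferred') names classes:
    one class per subdirectory, in alphanumeric order, so label i -> class_names[i]."""
    return sorted(
        entry for entry in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, entry))
    )


def make_datasets(path_to_zip):
    """Hypothetical helper: build the training and validation datasets."""
    import tensorflow as tf  # lazy import: only needed when the datasets are built

    base_dir = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
    train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
        os.path.join(base_dir, 'train'),
        shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
    validation_dataset = tf.keras.preprocessing.image_dataset_from_directory(
        os.path.join(base_dir, 'validation'),
        shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
    return train_dataset, validation_dataset
```

Usage: `train_dataset, validation_dataset = make_datasets(path_to_zip)`. Because the subfolders are named cats and dogs, `train_dataset.class_names` is then `['cats', 'dogs']`, so label 0 means cat and label 1 means dog.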


# Show the first 9 images and labels of the training set
class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
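# dataset.take(1) builds a dataset from the first element only (the first one, not a random one).
# train_dataset was batched when it was read from disk (32 images per batch), so take(1) yields one batch of 32 images; the loop shows the first 9 of them.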
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()
# print(train_dataset.take(1))  # <TakeDataset shapes: ((None, 160, 160, 3), (None,)), types: (tf.float32, tf.int32)>
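Why take(1) returns 32 images rather than 9: the dataset is batched, so each element is a whole batch. A small pure-Python sketch of the batch sizes, assuming the 2,000 training images of cats_and_dogs_filtered and a batch size of 32 (batch_sizes is a hypothetical helper). The last batch is partial because 2,000 is not a multiple of 32.

```python
import math


def batch_sizes(n_images, batch_size):
    """Sizes of the batches obtained by batching n_images elements (the last one may be partial)."""
    n_batches = math.ceil(n_images / batch_size)
    return [min(batch_size, n_images - i * batch_size) for i in range(n_batches)]


sizes = batch_sizes(2000, 32)
print(len(sizes), sizes[0], sizes[-1])  # 63 32 16
```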

Configure the datasets for performance: use buffered prefetching to load images from disk so that I/O does not block training.

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)  # prefetch: prepare the next batches while the current training step runs
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
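The idea behind prefetch can be illustrated without TensorFlow. The sketch below is only a conceptual analogue, not how tf.data implements it: a background thread fills a bounded buffer (buffer_size plays the role of the prefetch buffer) while the consumer works on the previous element, so producing and consuming overlap. prefetch_iter is a hypothetical name.

```python
import queue
import threading


def prefetch_iter(iterable, buffer_size=2):
    """Yield the items of `iterable` in order while a producer thread prepares the next ones."""
    buf = queue.Queue(maxsize=buffer_size)  # bounded: at most buffer_size items waiting
    done = object()                          # sentinel marking the end of the stream

    def producer():
        try:
            for item in iterable:
                buf.put(item)  # blocks while the buffer is full
        finally:
            buf.put(done)      # always signal the end, even if iteration raised

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is done:
            return
        yield item


# Consume "batches" while the next ones are being prepared in the background
print(list(prefetch_iter(range(5), buffer_size=2)))  # [0, 1, 2, 3, 4]
```

As with AUTOTUNE in tf.data, a larger buffer hides more producer latency at the cost of memory; order is preserved because there is a single producer and a FIFO queue.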

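Here we use data augmentation to reduce overfitting. When you do not have a large image dataset, it is good practice to introduce sample diversity artificially by applying random but realistic transformations to the training images, such as rotations or horizontal flips. This exposes the model to different aspects of the training data and reduces overfitting.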

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Note: these layers are active only during training, when you call model.fit. They are inactive when the model is used in inference mode, for example in model.evaluate or model.predict.
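To make the two layers concrete, here is a NumPy illustration (not the Keras implementation; random_flip_horizontal and rotation_range_degrees are hypothetical helpers). Each image in the batch is flipped left-right with probability 0.5 only when training is true, and the input is returned unchanged in inference mode, as in the note above. For RandomRotation, Keras interprets the factor as a fraction of 2π, so 0.2 means a random angle in [-72°, 72°].

```python
import numpy as np


def random_flip_horizontal(images, rng, training=True):
    """images: array of shape (batch, height, width, channels)."""
    if not training:
        return images  # inference mode: the layer is a no-op
    flip = rng.random(images.shape[0]) < 0.5      # one coin toss per image
    out = images.copy()
    out[flip] = images[flip][:, :, ::-1, :]        # reverse the width axis of the chosen images
    return out


def rotation_range_degrees(factor):
    """RandomRotation(factor): angle drawn from [-factor * 360, factor * 360] degrees."""
    return (-factor * 360.0, factor * 360.0)


rng = np.random.default_rng(0)
batch = np.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)  # 2 tiny 2x3 single-channel "images"
augmented = random_flip_horizontal(batch, rng)
print(rotation_range_degrees(0.2))
```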

Let's apply the augmentation repeatedly to the same image to see its effect:

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))  # tf.expand_dims adds a batch axis at axis=0
    plt.imshow(augmented_image[0] / 255)  # imshow expects float images in [0, 1]
    plt.axis('off')
plt.show()

The result looks like this:
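Next, we will use tf.keras.applications.MobileNetV2 as the base model. This model expects pixel values in the range [-1, 1], but at this point the pixel values of our images are in [0, 255]. To rescale them, we use the preprocessing method that comes with the model.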

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
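The rescaling done by mobilenet_v2.preprocess_input is a fixed affine map, x / 127.5 - 1, which sends 0 to -1 and 255 to 1. A NumPy sketch of the same arithmetic (scale_to_minus1_1 is a hypothetical name); Keras's Rescaling(1./127.5, offset=-1) preprocessing layer expresses the same mapping as a layer.

```python
import numpy as np


def scale_to_minus1_1(pixels):
    """Map pixel values from [0, 255] to [-1, 1], the range MobileNetV2 expects."""
    x = np.asarray(pixels, dtype=np.float32)
    return x / 127.5 - 1.0


print(scale_to_minus1_1([0, 127.5, 255]).tolist())  # [-1.0, 0.0, 1.0]
```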

2. Creating the base model from a pre-trained convolutional network

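We create the base model from the MobileNet V2 model developed at Google. This model has been pre-trained on the ImageNet dataset, a large dataset of 1.4 million images and 1,000 classes. ImageNet is a research training dataset with a wide variety of categories, such as jackfruit and syringe. This base of knowledge will help us classify the cats and dogs in our specific dataset.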

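First, you need to choose which layer of MobileNet V2 to use for feature extraction. The final classification layer (at the "top", since most diagrams of machine learning models go from bottom to top) is not very useful. Instead, following common practice, you rely on the last layer before the flatten operation. This layer is called the "bottleneck layer". Compared with the final (top) layer, the features of the bottleneck layer retain more generality.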

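To do this, instantiate a MobileNet V2 model pre-loaded with weights trained on ImageNet. Passing include_top=False loads the network without the classification layers at the top, which is ideal for feature extraction.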

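Keras's models with pre-trained weights can be used for prediction, feature extraction and fine-tuning. Available models include Xception, VGG16, ResNet50, MobileNetV2 and others; usage examples for them can be found at: https://blog.csdn.net/weixin_39506322/article/details/88640679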

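IMG_SIZE = (160, 160)  # must match the image_size used when building the datasets
IMG_SHAPE = IMG_SIZE + (3,)  # (160, 160, 3)

base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,  # input images of shape (160, 160, 3)
                                               include_top=False,  # exclude the fully connected classification layers on top
                                               weights='imagenet')  # 'imagenet' loads weights pre-trained on ImageNet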


This feature extractor converts each 160x160x3 image into a 5x5x1280 block of features. Let's see what it does to a batch of example images:

image_batch, label_batch = next(iter(train_dataset))
print(image_batch.shape)  # (32, 160, 160, 3)
feature_batch = base_model(image_batch)
print(feature_batch.shape)  # (32, 5, 5, 1280)
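Where 5x5x1280 comes from: MobileNetV2 halves the spatial size five times (Conv1 and the depthwise convolutions of blocks 1, 3, 6 and 13 in the summary below use stride 2), so 160 → 80 → 40 → 20 → 10 → 5, and the last 1x1 convolution of the headless model outputs 1280 channels. A pure-Python sketch of the spatial arithmetic, assuming each stride-2 layer produces ceil(n / 2) outputs (downsample is a hypothetical helper).

```python
import math


def downsample(size, n_stride2_layers=5):
    """Spatial sizes after each of n stride-2 layers, each producing ceil(n / 2) outputs."""
    sizes = [size]
    for _ in range(n_stride2_layers):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


print(downsample(160))  # [160, 80, 40, 20, 10, 5]
feature_shape = (downsample(160)[-1],) * 2 + (1280,)
print(feature_shape)    # (5, 5, 1280)
```

The same arithmetic gives a 7x7 feature map for the 224x224 input size MobileNetV2 is commonly used with.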

3. Feature extraction

3.1 Freezing the convolutional base

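It is essential to freeze the convolutional base before compiling and training the model. Freezing (by setting layer.trainable = False) prevents the weights of a given layer from being updated during training. MobileNet V2 has many layers, so setting the trainable flag of the entire model to False freezes all of them.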

base_model.trainable = False
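The Param # column of the summary below can be checked by hand, which also shows what freezing affects. MobileNetV2's convolutions have no bias, so a k×k convolution has k·k·C_in·C_out weights and a depthwise one k·k·C; a BatchNormalization layer stores 4 values per channel, of which gamma and beta are trainable while the moving mean and variance never are. A sketch with hypothetical helper names, checked against the first layers of the summary; the frozen flag models the effect of base_model.trainable = False.

```python
def conv_params(k, c_in, c_out):
    """Weights of a bias-free k x k convolution."""
    return k * k * c_in * c_out


def depthwise_params(k, channels):
    """Weights of a bias-free k x k depthwise convolution (one filter per channel)."""
    return k * k * channels


def batchnorm_params(channels, frozen=False):
    """(trainable, non_trainable) counts: gamma/beta train, moving mean/variance never do."""
    gamma_beta, moving_stats = 2 * channels, 2 * channels
    if frozen:  # base_model.trainable = False also freezes gamma and beta
        return 0, gamma_beta + moving_stats
    return gamma_beta, moving_stats


print(conv_params(3, 3, 32))    # 864 (Conv1)
print(depthwise_params(3, 32))  # 288 (expanded_conv_depthwise)
print(conv_params(1, 32, 16))   # 512 (expanded_conv_project)
print(batchnorm_params(32))     # (64, 64), 128 in total (bn_Conv1)
```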

We can inspect the model's structure with base_model.summary():

Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________
3.2 Add a classification head

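To generate predictions from the block of features, use a tf.keras.layers.GlobalAveragePooling2D layer to average over the 5x5 spatial locations. This converts the features into a single 1280-element vector per image.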

# Convert the features into one 1280-element vector per image
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
# print(feature_batch_average.shape)  # (32, 1280)
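To make the pooling step concrete, here is a minimal pure-Python sketch of what global average pooling computes: for each channel, the mean over all H x W spatial positions. It is an illustration, not Keras's implementation. `global_average_pool` is a hypothetical helper, and the feature map is a nested list indexed `[row][col][channel]`, mirroring one image of the `(5, 5, 1280)` base-model output.

```python
def global_average_pool(feature_map):
    """Average a [H][W][C] nested list over H and W, returning a length-C list."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    sums = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for c, value in enumerate(pixel):
                sums[c] += value
    return [s / (height * width) for s in sums]


# Toy 5x5 feature map with 3 channels (the real base model output has 1280 channels)
toy = [[[float(r), float(c), 1.0] for c in range(5)] for r in range(5)]
print(global_average_pool(toy))  # [2.0, 2.0, 1.0]
```

Each spatial grid collapses to one number per channel, so a `(5, 5, 1280)` map becomes a 1280-element vector.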

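Apply a tf.keras.layers.Dense layer to convert these features into a single prediction per image. No activation function is needed here, because this prediction is treated as a logit (a raw prediction value): positive numbers predict class 1, negative numbers predict class 0.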

prediction_layer = tf.keras.layers.Dense(1)  # Dense layer: one prediction (logit) per image
prediction_batch = prediction_layer(feature_batch_average)
# print(prediction_batch.shape)  # (32, 1)
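A Dense layer with one unit computes one logit per image, `z = w . x + b`. Below is a hedged pure-Python sketch of that computation for a small batch. It uses made-up 4-element feature vectors instead of 1280-element ones. `dense_logits`, `predicted_class` and the weight values are illustrative, not taken from the trained model.

```python
def dense_logits(batch, weights, bias):
    """One logit per example: dot(features, weights) + bias, shaped like (N, 1)."""
    return [[sum(x * w for x, w in zip(features, weights)) + bias] for features in batch]


def predicted_class(logit):
    """Non-negative logit -> class 1, negative -> class 0 (same as thresholding sigmoid at 0.5)."""
    return 1 if logit >= 0 else 0


# Illustrative 4-feature vectors and weights (the real layer has 1280 inputs)
batch = [[1.0, 0.0, 2.0, 0.5], [0.0, 3.0, 0.0, 1.0]]
weights = [0.5, -1.0, 0.25, 0.0]
bias = -0.5
logits = dense_logits(batch, weights, bias)
print(logits)                                   # [[0.5], [-3.5]]
print([predicted_class(z[0]) for z in logits])  # [1, 0]
```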

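Build the model by chaining together the data augmentation, rescaling, base_model and feature extractor layers with the Keras Functional API. As mentioned earlier, because our model contains BatchNormalization layers, use training = False. (When layer.trainable = False is set, a BatchNormalization layer runs in inference mode and does not update its mean and variance statistics. **When you unfreeze a model that contains BatchNormalization layers for fine-tuning, keep those layers in inference mode by passing training = False when calling the base model.** Otherwise, the updates applied to the non-trainable weights will destroy what the model has already learned.)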

# Chain data augmentation, rescaling, base_model and the feature extractor layers with the Keras Functional API
inputs = tf.keras.Input(shape=(160, 160, 3))  # fixed input size
x = data_augmentation(inputs)  # data augmentation
x = preprocess_input(x)  # input preprocessing (rescales pixels to [-1, 1])
x = base_model(x, training=False)  # keep the BatchNormalization layers in inference mode
x = global_average_layer(x)  # one vector per image
x = tf.keras.layers.Dropout(0.2)(x)  # dropout
outputs = prediction_layer(x)  # output logit
model = tf.keras.Model(inputs, outputs)
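The reason for training=False can be sketched for a single feature. In training mode, Keras's BatchNormalization normalizes with the current batch's mean and variance and updates its moving averages. In inference mode it uses the stored moving statistics and leaves them alone. `ToyBatchNorm` is a hypothetical, simplified class with no learned scale/offset. The momentum/epsilon values are illustrative, and the starting moving statistics stand in for those learned on ImageNet.

```python
import math


class ToyBatchNorm:
    """Simplified 1-feature batch norm (no gamma/beta), for illustration only."""

    def __init__(self, momentum=0.99, epsilon=1e-3):
        self.momentum = momentum
        self.epsilon = epsilon
        self.moving_mean = 0.0  # stands in for statistics learned on ImageNet
        self.moving_var = 1.0

    def __call__(self, batch, training):
        if training:
            # Training mode: normalize with this batch's statistics and
            # update the moving averages (this alters the pre-trained statistics)
            mean = sum(batch) / len(batch)
            var = sum((x - mean) ** 2 for x in batch) / len(batch)
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            # Inference mode: use the stored moving statistics, leave them unchanged
            mean, var = self.moving_mean, self.moving_var
        return [(x - mean) / math.sqrt(var + self.epsilon) for x in batch]


bn = ToyBatchNorm()
batch = [4.0, 6.0]
bn(batch, training=False)
print(bn.moving_mean)  # 0.0: inference mode leaves the statistics untouched
bn(batch, training=True)
print(bn.moving_mean)  # about 0.05: drifted toward the batch mean of 5.0
```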

3.3 Compile the model

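base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              # The model outputs raw logits; from_logits=True makes the loss apply the sigmoid itself
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              # Threshold the raw logits at 0 (equivalent to probability 0.5)
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])
# model.summary()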
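from_logits=True tells the loss that the model outputs raw logits, so the sigmoid is applied inside the loss. The pure-Python sketch below compares the naive formula with the numerically stable rewrite `max(z, 0) - z*y + log(1 + exp(-|z|))`. That rewrite is the form documented for `tf.nn.sigmoid_cross_entropy_with_logits`. The function names are hypothetical, and only the per-example loss is shown; Keras additionally averages over the batch.

```python
import math


def bce_naive(y, z):
    """Binary cross-entropy on the probability sigmoid(z); can hit log(0)."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_from_logits(y, z):
    """Numerically stable form: max(z, 0) - z*y + log(1 + exp(-|z|))."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


for y, z in [(1, 2.0), (0, 2.0), (1, -0.5), (0, 0.0)]:
    print(y, z, round(bce_naive(y, z), 6), round(bce_from_logits(y, z), 6))

# For a very confident wrong logit, sigmoid(40.0) rounds to 1.0, so the naive
# version raises on log(0), while the stable version still returns ~= 40.0.
print(bce_from_logits(0, 40.0))
```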

The model structure is as follows:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

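The roughly 2.26 million parameters of MobileNet V2 are frozen, while the Dense layer has 1,281 trainable parameters. These are split between two tf.Variable objects: the weights and the bias.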

print(len(model.trainable_variables))  # 2
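The parameter counts in the summary can be checked by hand. A Dense layer with 1280 inputs and 1 unit has a 1280x1 kernel plus 1 bias. Adding that to the frozen base model's count gives the model total. A small arithmetic check in Python, using only numbers from the summaries above:

```python
features = 1280   # length of the pooled feature vector
units = 1         # Dense(1): a single logit per image

kernel_params = features * units  # weight matrix of shape (1280, 1)
bias_params = units               # bias vector of shape (1,)
dense_params = kernel_params + bias_params

base_model_params = 2_257_984     # MobileNet V2 feature extractor, all frozen

print(dense_params)                      # 1281, held in 2 variables (kernel, bias)
print(base_model_params + dense_params)  # 2259265, the "Total params" of the model
```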

4. Train the model

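After training for 10 epochs, you should see about 95% accuracy on the validation set. First, evaluate the untrained model as a baseline: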

loss0, acc0 = model.evaluate(validation_dataset)
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(acc0))

Output:

26/26 [==============================] - 3s 69ms/step - loss: 0.9336 - accuracy: 0.4220
initial loss: 0.93
initial accuracy: 0.42

Before training, the model's accuracy is 42%, essentially chance level for a two-class problem. Let's see the result after 10 epochs:

initial_epochs = 10
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)

We can see that the accuracy rises to about 95%.
[Figure: training log of the 10 epochs]
We use plt to plot the learning curves of training and validation accuracy/loss when using the MobileNet V2 base model as a fixed feature extractor.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()), 1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0, 1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

The resulting curves:
[Figure: training and validation accuracy/loss curves]
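The validation metrics are clearly better than the training metrics. The main reason is that layers such as tf.keras.layers.BatchNormalization and tf.keras.layers.Dropout affect accuracy during training; they are turned off when the validation loss is computed.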

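To a lesser extent, this is also because the training metrics report the average over an epoch, whereas the validation metrics are evaluated after the epoch. The validation metrics therefore see a model that has trained slightly longer.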
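To see why Dropout makes the training metrics look worse, here is a hedged pure-Python sketch of inverted dropout. Keras's `Dropout` layer behaves this way: during training each value is zeroed with probability `rate` and the survivors are scaled by `1 / (1 - rate)`; at inference the input passes through unchanged. The function name and the fixed random seed are illustrative choices.

```python
import random


def dropout(values, rate, training, rng):
    """Inverted dropout: zero with prob `rate` and rescale survivors while training."""
    if not training:
        return list(values)  # inference: identity
    keep = 1.0 - rate
    return [v / keep if rng.random() >= rate else 0.0 for v in values]


rng = random.Random(0)
features = [1.0] * 10_000
train_out = dropout(features, rate=0.2, training=True, rng=rng)
infer_out = dropout(features, rate=0.2, training=False, rng=rng)

print(train_out.count(0.0) / len(train_out))  # roughly 0.2 of the features are dropped
print(sum(train_out) / len(train_out))         # close to 1.0: the expected value is preserved
print(infer_out == features)                   # True: no noise at inference time
```

The training-time noise is what lowers the training accuracy relative to the noise-free validation pass.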

5. Fine-tuning

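In the feature extraction experiment above, we only trained a few layers on top of the MobileNet V2 base model. The weights of the pre-trained network were not updated during training.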

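**One way to increase performance even further is to train (or "fine-tune") the weights of the top layers of the pre-trained model alongside the training of the classifier you added.** The training process will force the weights to be tuned from generic feature maps to features associated specifically with the dataset.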

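Note: attempt this only after you have trained the top-level classifier with the pre-trained model set to non-trainable. If you add a randomly initialized classifier on top of a pre-trained model and attempt to train all layers jointly, the magnitude of the gradient updates will be too large (due to the random weights of the classifier), and the pre-trained model will forget what it has learned.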

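Also, you should try to fine-tune a small number of top layers rather than the whole MobileNet model. **In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images.** As you go higher up, the features become increasingly specific to the dataset on which the model was trained. The goal of fine-tuning is to adapt these specialized features to the new dataset, rather than overwrite the generic learning.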

5.1 Unfreeze the top layers of the model

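Unfreeze base_model and set the bottom layers to be non-trainable. Then recompile the model (this is required for the changes to take effect) and resume training.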

# Unfreeze the top layers of the model
base_model.trainable = True

# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

The base model has 154 layers:

Number of layers in the base model:  154
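The slicing logic can be checked without TensorFlow. This sketch uses a hypothetical `Layer` stand-in class that has only a `trainable` flag. It builds a list of 154 of them, matching the layer count printed above. It shows that `layers[:fine_tune_at]` freezes indices 0 to 99 and leaves the top 54 layers trainable. The number of trainable *variables* is a different count, because many layers (ReLU, Add, ...) have no weights.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical stand-in for a Keras layer: only the `trainable` flag matters here."""
    name: str
    trainable: bool = False


layers = [Layer(f"layer_{i}") for i in range(154)]

# Same steps as the fine-tuning code: unfreeze everything, then re-freeze the bottom layers
for layer in layers:
    layer.trainable = True
fine_tune_at = 100
for layer in layers[:fine_tune_at]:
    layer.trainable = False

trainable = [layer.name for layer in layers if layer.trainable]
print(len(trainable))               # 54
print(trainable[0], trainable[-1])  # layer_100 layer_153
```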

5.2 Compile the model

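As we are now **training a much larger model and want to readapt the pre-trained weights, it is important to use a lower learning rate at this stage.** Otherwise, the model could overfit very quickly.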

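print("-------------------------Fine Tuning-------------------------")
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              # Use a lower learning rate when training a much larger model and re-adapting pre-trained weights
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate / 10),
              # Threshold the raw logits at 0 (equivalent to probability 0.5)
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])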
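Why the lower learning rate matters can be seen from a single optimizer step. Below is the textbook RMSprop update (`v = rho*v + (1-rho)*g**2`, `w -= lr * g / (sqrt(v) + eps)`) written in pure Python. This is an assumption-laden sketch: Keras's `RMSprop` also supports momentum and a centered variant and may place epsilon differently, so treat it as an illustration rather than the library's exact implementation. The weight and gradient values are made up.

```python
import math


def rmsprop_step(w, g, v, lr, rho=0.9, eps=1e-7):
    """One textbook RMSprop update; returns (new_weight, new_velocity)."""
    v = rho * v + (1.0 - rho) * g * g
    w = w - lr * g / (math.sqrt(v) + eps)
    return w, v


base_learning_rate = 0.0001
w0, g = 0.5, 0.3  # illustrative pre-trained weight and gradient

w_fast, _ = rmsprop_step(w0, g, 0.0, lr=base_learning_rate)
w_slow, _ = rmsprop_step(w0, g, 0.0, lr=base_learning_rate / 10)

print(abs(w_fast - w0) / abs(w_slow - w0))  # ~10: the step is ten times smaller
```

For the same gradient, dividing the learning rate by 10 shrinks each weight change by 10, which keeps the pre-trained weights close to where they started.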

Use model.summary() to inspect the model structure at this point:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________

How many trainable variables does the model have now?

print(len(model.trainable_variables))  # 56

5.3 Continue training the model

If you trained to convergence earlier, this step should improve your accuracy by a few percentage points.

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
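Keras's `fit` treats `epochs` as the index at which training stops rather than a count, and it starts at `initial_epoch`. The plain-Python arithmetic below mirrors the variables in the code above. It shows how many epochs the fine-tuning run actually performs and where the "Start Fine Tuning" line in the later plot sits. `history_epoch` is a stand-in for `history.epoch`, which holds the indices 0 to 9 after the first 10-epoch run.

```python
initial_epochs = 10
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

history_epoch = list(range(initial_epochs))  # stand-in for history.epoch after the first fit
initial_epoch = history_epoch[-1]            # 9, the last completed epoch index

# fit() runs epoch indices initial_epoch .. total_epochs - 1
fine_tune_indices = list(range(initial_epoch, total_epochs))
print(initial_epoch)                              # 9
print(len(fine_tune_indices))                     # 11: restarts at index 9, so one more than fine_tune_epochs
print(initial_epochs + len(fine_tune_indices))    # 21 points in the concatenated curves
print(initial_epochs - 1)                         # 9: x position of the 'Start Fine Tuning' line
```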

After fine-tuning, the model reaches nearly 98% accuracy on the validation set.
[Figure: training log of the fine-tuning epochs]

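Let's look at the learning curves of training and validation accuracy/loss after fine-tuning the last few layers of the MobileNet V2 base model and training the classifier on top of it. The validation loss is much higher than the training loss, so there may be some overfitting.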

# Concatenate the learning curves of the two training runs and plot them together
acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs - 1, initial_epochs - 1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs - 1, initial_epochs - 1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

[Figure: training and validation accuracy/loss curves before and after the start of fine-tuning]

You may also get some overfitting because the new training set is relatively small and similar to the original MobileNet V2 datasets.
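One side effect of the plotting code is worth knowing about. `acc = history.history['accuracy']` binds a name to the same list object, so `acc += ...` also extends `history.history['accuracy']` in place. The plain-Python sketch below uses a dict named `history_dict` as a stand-in for `history.history` (the values are made up). It shows the aliasing and a copy-based alternative.

```python
history_dict = {"accuracy": [0.6, 0.8]}   # stand-in for history.history (made-up values)
fine_tune_accuracy = [0.9, 0.95]          # stand-in for history_fine.history['accuracy']

acc = history_dict["accuracy"]            # same list object, not a copy
acc += fine_tune_accuracy                 # list += extends in place
print(history_dict["accuracy"])           # [0.6, 0.8, 0.9, 0.95]: the original was modified too

history_dict = {"accuracy": [0.6, 0.8]}
acc_copy = list(history_dict["accuracy"]) + fine_tune_accuracy  # builds a new list
print(history_dict["accuracy"])           # [0.6, 0.8]: the original is untouched
print(acc_copy)                           # [0.6, 0.8, 0.9, 0.95]
```

The in-place version is harmless for a one-off plot, but copying avoids surprises if `history` is reused later.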

6. Evaluation and prediction

Finally, you can verify the model's performance on new data using the test set.

loss, accuracy = model.evaluate(test_dataset)
# 6/6 [==============================] - 1s 79ms/step - loss: 0.0157 - accuracy: 0.9948

Now you can use this model to predict whether your pet is a cat or a dog.

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()  # as_numpy_iterator() yields NumPy batches; next() takes the first one
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)  # probability < 0.5 -> class 0, >= 0.5 -> class 1

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

Output:

Predictions:
 [0 0 0 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 0 1 1 1 1 1 0 0 0 0 0 0 1 0]
Labels:
 [0 0 0 1 0 0 1 1 1 1 0 1 1 1 1 0 1 1 0 1 1 1 1 1 0 0 0 0 0 0 1 0]
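The sigmoid is monotonic and sigmoid(0) = 0.5. Thresholding the probabilities at 0.5 therefore gives exactly the same labels as thresholding the raw logits at 0, so the sigmoid step is only needed when you want the probabilities themselves. A pure-Python check with made-up logits:

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


logits = [-3.2, -0.01, 0.0, 0.4, 5.0]  # made-up model outputs

from_probs = [0 if sigmoid(z) < 0.5 else 1 for z in logits]  # same rule as tf.where(p < 0.5, 0, 1)
from_logits = [0 if z < 0 else 1 for z in logits]             # threshold the logits at 0

print(from_probs)                 # [0, 0, 1, 1, 1]
print(from_probs == from_logits)  # True
```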

Let's display the images titled with their predicted class to check whether the predictions are correct:

plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")
plt.show()

[Figure: 3x3 grid of test images titled with their predicted class]

7. Summary

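**Using a pre-trained model for feature extraction:** When working with a small dataset, a common practice is to take advantage of features learned by a model trained on a larger dataset in the same domain. ==To do this, you instantiate the pre-trained model and add a fully connected classifier on top. The pre-trained model is "frozen", and only the weights of the classifier are updated during training.== In this case, the convolutional base extracted all the features associated with each image, and you just trained a classifier that determines the image class from that set of extracted features.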

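**Fine-tuning a pre-trained model:** To further improve performance, you may want to repurpose the top layers of the pre-trained model for the new dataset via fine-tuning. ==In this case, you tuned your weights so that the model learned high-level features specific to the dataset.== This technique is usually recommended when the training dataset is large and very similar to the original dataset that the pre-trained model was trained on.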
