Machine Learning 18: Implementing Transfer Learning with Keras, and How It Works

Transfer learning means taking a neural network that has already been trained and adapting it for use on a new, different dataset.

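1. A transfer learning example in Keras

First, a transfer learning example implemented in Keras (see the Keras Applications documentation):
https://keras-cn.readthedocs.io/en/latest/other/application/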


import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd

import sklearn
import os
import random
import csv

import tensorflow as tf

from numpy import expand_dims
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from PIL import Image

from keras.preprocessing import image
import keras as K

1.1 Preparing the training data

BATCH_SIZE = 16

# Train for 50 epochs
EPOCHS = 50

# Input image size is set to 224, as the model specifies
IMAGE_SIZE = 224

TRAIN_PATH = './17flowerclasses/train'
TEST_PATH = './17flowerclasses/test'
#FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']

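# Apply data augmentation to the training images
train_datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
# target_size resizes every image on load, because the network
# needs all training images to have the same size
train_generator = train_datagen.flow_from_directory(
        directory=TRAIN_PATH,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=BATCH_SIZE,
        # classes=FLOWER_CLASSES,  # optional: pin the class order explicitly
)
# No augmentation for the test images; batch_size is left at its default
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(
        directory=TEST_PATH,
        target_size=(IMAGE_SIZE, IMAGE_SIZE),
        # classes=FLOWER_CLASSES,  # optional: pin the class order explicitly
)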
Found 1190 images belonging to 17 classes.
Found 170 images belonging to 17 classes.
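
From the image counts reported above, the number of batches in one pass over each set follows from a ceiling division. The sketch below is plain arithmetic, not part of the original notebook. It assumes the Keras default of `batch_size=32` for the test generator, since the original call does not set one, and the helper name `batches_per_epoch` is made up for illustration.

```python
import math

def batches_per_epoch(num_images, batch_size):
    """Batches needed to see every image once (the last batch may be smaller)."""
    return math.ceil(num_images / batch_size)

# 1190 training images with BATCH_SIZE = 16
train_batches = batches_per_epoch(1190, 16)
# 170 test images, assuming the default batch size of 32
test_batches = batches_per_epoch(170, 32)

print(train_batches, test_batches)
```

These are the kind of values one would typically pass as `steps_per_epoch` and `validation_steps` when training from generators.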

1.2 Designing the transfer learning network

Here we take the ResNet50 model and replace its output layer with two fully connected layers, which together form the new network.
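
To get a feel for the size of such a head, the arithmetic below counts its trainable parameters. It is a rough sketch under stated assumptions: ResNet50 without its top produces a 7x7x2048 feature map for 224x224 inputs, and the head is Flatten, Dense(1024), Dropout, Dense(17). The global-average-pooling variant is a hypothetical alternative shown only for comparison, and `dense_params` is an illustrative helper, not a Keras function.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Feature map of ResNet50 (include_top=False) for 224x224 inputs
feature_h, feature_w, feature_c = 7, 7, 2048
num_classes = 17

# Head in this section: Flatten -> Dense(1024) -> Dropout -> Dense(17)
# (Flatten and Dropout have no parameters of their own)
flat_units = feature_h * feature_w * feature_c
flatten_head = dense_params(flat_units, 1024) + dense_params(1024, num_classes)

# Hypothetical alternative: GlobalAveragePooling2D -> Dense(1024) -> Dense(17)
gap_head = dense_params(feature_c, 1024) + dense_params(1024, num_classes)

print(flat_units)    # 100352
print(flatten_head)  # 102778897
print(gap_head)      # 2115601
```

Almost all of the head's roughly 103 million parameters sit in the first dense layer, which is worth keeping in mind with only 1190 training images.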

import keras
from keras.models import Model, Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D

IMAGE_SIZE = 224

# Load ResNet50 pre-trained on ImageNet, without its original classification layers
keras_resnet = keras.applications.resnet50.ResNet50(
        include_top=False,
        weights='imagenet',
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# Shape of the last feature map, with the batch dimension dropped
print(keras_resnet.output_shape[1:])

# New classification head: flatten, one hidden dense layer, 17-way softmax
resnet_flower = Sequential()
resnet_flower.add(Flatten(input_shape=keras_resnet.output_shape[1:]))
resnet_flower.add(Dense(1024, activation="relu"))
resnet_flower.add(Dropout(0.5))
resnet_flower.add(Dense(17, activation='softmax'))

# Chain the pre-trained base and the new head into a single model
resnet_flower_model = Model(inputs=keras_resnet.input, outputs=resnet_flower(keras_resnet.output))
resnet_flower_model.summary()
/home/leon/anaconda3/lib/python3.7/site-packages/keras_applications/resnet50.py:265: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.
  warnings.warn('The output shape of `ResNet50(include_top=False)` '


(7, 7, 2048)
Model: "model_9"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_11 (InputLayer)           (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_11[0][0]                   
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_189 (Activation)     (None, 112, 112, 64) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           activation_189[0][0]             
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D)  (None, 56, 56, 64)   0           pool1_pad[0][0]                  
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 56, 56, 64)   4160        max_pooling2d_9[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 56, 56, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_190 (Activation)     (None, 56, 56, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_190[0][0]             
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_191 (Activation)     (None, 56, 56, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_191[0][0]             
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 56, 56, 256)  16640       max_pooling2d_9[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 56, 56, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_1 (Add)                     (None, 56, 56, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_192 (Activation)     (None, 56, 56, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 56, 56, 64)   16448       activation_192[0][0]             
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 56, 56, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_193 (Activation)     (None, 56, 56, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_193[0][0]             
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_194 (Activation)     (None, 56, 56, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_194[0][0]             
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_2 (Add)                     (None, 56, 56, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_192[0][0]             
__________________________________________________________________________________________________
activation_195 (Activation)     (None, 56, 56, 256)  0           add_2[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 56, 56, 64)   16448       activation_195[0][0]             
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 56, 56, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_196 (Activation)     (None, 56, 56, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 56, 56, 64)   36928       activation_196[0][0]             
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_197 (Activation)     (None, 56, 56, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 56, 56, 256)  16640       activation_197[0][0]             
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_3 (Add)                     (None, 56, 56, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_195[0][0]             
__________________________________________________________________________________________________
activation_198 (Activation)     (None, 56, 56, 256)  0           add_3[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_198[0][0]             
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_199 (Activation)     (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_199[0][0]             
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_200 (Activation)     (None, 28, 28, 128)  0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_200[0][0]             
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 28, 28, 512)  131584      activation_198[0][0]             
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512)  2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_4 (Add)                     (None, 28, 28, 512)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_201 (Activation)     (None, 28, 28, 512)  0           add_4[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_201[0][0]             
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_202 (Activation)     (None, 28, 28, 128)  0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_202[0][0]             
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_203 (Activation)     (None, 28, 28, 128)  0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_203[0][0]             
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 28, 28, 512)  0           bn3b_branch2c[0][0]              
                                                                 activation_201[0][0]             
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 28, 28, 512)  0           add_5[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_204[0][0]             
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_205 (Activation)     (None, 28, 28, 128)  0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_205[0][0]             
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 28, 28, 128)  0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_206[0][0]             
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_6 (Add)                     (None, 28, 28, 512)  0           bn3c_branch2c[0][0]              
                                                                 activation_204[0][0]             
__________________________________________________________________________________________________
activation_207 (Activation)     (None, 28, 28, 512)  0           add_6[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_207[0][0]             
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 28, 28, 128)  0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_208[0][0]             
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_209 (Activation)     (None, 28, 28, 128)  0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_209[0][0]             
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_207[0][0]             
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 28, 28, 512)  0           add_7[0][0]                      
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_210[0][0]             
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_211 (Activation)     (None, 14, 14, 256)  0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_211[0][0]             
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 14, 14, 256)  0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_212[0][0]             
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 14, 14, 1024) 525312      activation_210[0][0]             
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_8 (Add)                     (None, 14, 14, 1024) 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_213 (Activation)     (None, 14, 14, 1024) 0           add_8[0][0]                      
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_213[0][0]             
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 14, 14, 256)  0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_214[0][0]             
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_215 (Activation)     (None, 14, 14, 256)  0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_215[0][0]             
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 14, 14, 1024) 0           bn4b_branch2c[0][0]              
                                                                 activation_213[0][0]             
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 14, 14, 1024) 0           add_9[0][0]                      
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_216[0][0]             
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_217 (Activation)     (None, 14, 14, 256)  0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_217[0][0]             
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 14, 14, 256)  0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_218[0][0]             
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_10 (Add)                    (None, 14, 14, 1024) 0           bn4c_branch2c[0][0]              
                                                                 activation_216[0][0]             
__________________________________________________________________________________________________
activation_219 (Activation)     (None, 14, 14, 1024) 0           add_10[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_219[0][0]             
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 14, 14, 256)  0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_220[0][0]             
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_221[0][0]             
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
                                                                 activation_219[0][0]             
__________________________________________________________________________________________________
activation_222 (Activation)     (None, 14, 14, 1024) 0           add_11[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_222[0][0]             
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_223 (Activation)     (None, 14, 14, 256)  0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_223[0][0]             
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_224 (Activation)     (None, 14, 14, 256)  0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_224[0][0]             
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 14, 14, 1024) 0           bn4e_branch2c[0][0]              
                                                                 activation_222[0][0]             
__________________________________________________________________________________________________
activation_225 (Activation)     (None, 14, 14, 1024) 0           add_12[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_225[0][0]             
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_226 (Activation)     (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_226[0][0]             
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_227 (Activation)     (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_227[0][0]             
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_13 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_225[0][0]             
__________________________________________________________________________________________________
activation_228 (Activation)     (None, 14, 14, 1024) 0           add_13[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 7, 7, 512)    524800      activation_228[0][0]             
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_229 (Activation)     (None, 7, 7, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_229[0][0]             
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_230 (Activation)     (None, 7, 7, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_230[0][0]             
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 7, 7, 2048)   2099200     activation_228[0][0]             
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_14 (Add)                    (None, 7, 7, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_231 (Activation)     (None, 7, 7, 2048)   0           add_14[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_231[0][0]             
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_232 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_232[0][0]             
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_233 (Activation)     (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_233[0][0]             
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 7, 7, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_231[0][0]             
__________________________________________________________________________________________________
activation_234 (Activation)     (None, 7, 7, 2048)   0           add_15[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_234[0][0]             
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_235 (Activation)     (None, 7, 7, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_235[0][0]             
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_236 (Activation)     (None, 7, 7, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_236[0][0]             
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_234[0][0]             
__________________________________________________________________________________________________
activation_237 (Activation)     (None, 7, 7, 2048)   0           add_16[0][0]                     
__________________________________________________________________________________________________
sequential_9 (Sequential)       (None, 17)           102778897   activation_237[0][0]             
==================================================================================================
Total params: 126,366,609
Trainable params: 126,313,489
Non-trainable params: 53,120
__________________________________________________________________________________________________

1.3 Compile the network


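from keras.callbacks import ModelCheckpoint

# Checkpoint callback: writes the weights to disk whenever the monitored
# metric (val_loss by default) improves
checkpointer = ModelCheckpoint(filepath='flowers.weights.best.hdf5', verbose=1,
                               save_best_only=True)

# Multi-class classification with one-hot labels, so use categorical cross-entropy
resnet_flower_model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])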

1.4 Train the network

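# Pass the checkpointer so that the best weights are actually saved during training.
# Note: fit_generator is deprecated in newer Keras/TensorFlow releases, where fit() accepts generators directly.
history_object = resnet_flower_model.fit_generator(train_generator, validation_data=test_generator,
                                                   epochs=EPOCHS, callbacks=[checkpointer])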
Epoch 1/50
75/75 [==============================] - 303s 4s/step - loss: 7.7597 - accuracy: 0.2034 - val_loss: 85776.1953 - val_accuracy: 0.0588
Epoch 2/50
75/75 [==============================] - 297s 4s/step - loss: 3.2533 - accuracy: 0.2672 - val_loss: 1551.2539 - val_accuracy: 0.0588
Epoch 3/50
75/75 [==============================] - 300s 4s/step - loss: 2.7691 - accuracy: 0.3202 - val_loss: 460.1471 - val_accuracy: 0.0588
Epoch 4/50
75/75 [==============================] - 299s 4s/step - loss: 2.2212 - accuracy: 0.3748 - val_loss: 6.4425 - val_accuracy: 0.1647
Epoch 5/50
75/75 [==============================] - 299s 4s/step - loss: 2.0888 - accuracy: 0.4277 - val_loss: 6.9810 - val_accuracy: 0.2412
Epoch 6/50
75/75 [==============================] - 300s 4s/step - loss: 1.8929 - accuracy: 0.4672 - val_loss: 71.1725 - val_accuracy: 0.1294
Epoch 7/50
75/75 [==============================] - 299s 4s/step - loss: 1.7380 - accuracy: 0.4815 - val_loss: 22.3084 - val_accuracy: 0.1529
Epoch 8/50
75/75 [==============================] - 301s 4s/step - loss: 1.6882 - accuracy: 0.5235 - val_loss: 6.6994 - val_accuracy: 0.4706
Epoch 9/50
75/75 [==============================] - 301s 4s/step - loss: 1.5664 - accuracy: 0.5555 - val_loss: 3.4593 - val_accuracy: 0.4176
Epoch 10/50
56/75 [=====================>........] - ETA: 1:13 - loss: 1.4389 - accuracy: 0.5700


---------------------------------------------------------------------------

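2. How to apply transfer learning depends mainly on two factors:

1. The size of the new dataset, and

2. How similar the new dataset is to the original dataset

The approach to transfer learning differs accordingly. There are four main cases:

The new dataset is small, and the new data is similar to the original training data

The new dataset is small, and the new data is different from the original training data

The new dataset is large, and the new data is similar to the original training data

The new dataset is large, and the new data is different from the original training data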


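A large dataset might have one million images; a small one might have 2,000. The boundary between large and small is somewhat subjective. With a small dataset, overfitting is a concern when applying transfer learning.

Images of dogs and images of wolves can be considered similar; they share common features. A dataset of flower images is different from a dataset of dog images.

Each of the four cases has its own approach.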
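As a rough summary of the four cases, the mapping from the two factors to an approach can be written as a small lookup. This is only an illustrative sketch: the function name `choose_strategy` and the wording of the returned strings are made up here, and deciding whether a dataset counts as "large" or "similar" remains a judgment call.

```python
def choose_strategy(dataset_is_large: bool, data_is_similar: bool) -> str:
    """Map the two factors to the approach described for each case (illustrative only)."""
    if not dataset_is_large and data_is_similar:
        return "case 1: freeze the pre-trained layers, train only a new final layer"
    if not dataset_is_large and not data_is_similar:
        return "case 2: keep only the early pre-trained layers (frozen), train a new final layer"
    if dataset_is_large and data_is_similar:
        return "case 3: replace the final layer, then retrain the whole network"
    return "case 4: train from scratch, or proceed as in case 3"


# Example: a small flower dataset judged similar to the original training data
print(choose_strategy(dataset_is_large=False, data_is_similar=True))
```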

Demonstration network

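To explain how each case works, we start with a generic pre-trained convolutional neural network and describe how to adapt it for each case. Our example network contains three convolutional layers and three fully connected layers:

[Figure: the example network, three convolutional layers followed by three fully connected layers]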

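Here is a general overview of what the convolutional neural network does:

The first layer detects edges in the image

The second layer detects shapes

The third convolutional layer detects higher-level features

Each transfer learning case uses the pre-trained network in a different way.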
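For reference, a network of this shape could be sketched in Keras as follows. The text only specifies three convolutional layers and three fully connected layers; the filter counts, the pooling layers, the 64x64 input size, the 1000 output classes and all layer names below are assumptions made for illustration.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_demo_network(num_classes=1000, input_shape=(64, 64, 3)):
    """Hypothetical demo network: 3 convolutional layers + 3 fully connected layers."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu", name="conv1")(inputs)  # edges
    x = layers.MaxPooling2D(name="pool1")(x)
    x = layers.Conv2D(32, 3, activation="relu", name="conv2")(x)       # shapes
    x = layers.MaxPooling2D(name="pool2")(x)
    x = layers.Conv2D(64, 3, activation="relu", name="conv3")(x)       # higher-level features
    x = layers.GlobalAveragePooling2D(name="gap")(x)
    x = layers.Dense(128, activation="relu", name="fc1")(x)
    x = layers.Dense(64, activation="relu", name="fc2")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="fc3")(x)
    return keras.Model(inputs, outputs, name="demo_network")


demo_network = build_demo_network()
demo_network.summary()
```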

2.1 Case 1: small dataset, similar data

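If the new dataset is small and similar to the original training data:

Remove the last layer of the network

Add a new fully connected layer whose output size matches the number of classes in the new dataset

Randomly initialize the weights of the new fully connected layer; freeze all weights from the pre-trained network

Train the network to update the weights of the new fully connected layer

To avoid overfitting on the small dataset, the weights of the original network are kept fixed rather than retrained.

Because the datasets are similar, their images share similar higher-level features. Most or all of the pre-trained layers therefore already contain information relevant to the new dataset and should be left unchanged.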

This approach can be visualized as follows:

[Figure: the network adapted for case 1]
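The steps for case 1 can be sketched in Keras roughly as follows. The small `build_demo_network` helper is a hypothetical stand-in for a pre-trained network (its weights here are random; a real workflow would load pre-trained ones), and the layer names, sizes and the choice of 17 classes are assumptions for illustration.

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_NEW_CLASSES = 17  # assumption: e.g. the 17 flower classes used earlier


def build_demo_network(num_classes=1000, input_shape=(64, 64, 3)):
    """Hypothetical stand-in for a pre-trained network (3 conv + 3 FC layers)."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu", name="conv1")(inputs)
    x = layers.MaxPooling2D(name="pool1")(x)
    x = layers.Conv2D(32, 3, activation="relu", name="conv2")(x)
    x = layers.MaxPooling2D(name="pool2")(x)
    x = layers.Conv2D(64, 3, activation="relu", name="conv3")(x)
    x = layers.GlobalAveragePooling2D(name="gap")(x)
    x = layers.Dense(128, activation="relu", name="fc1")(x)
    x = layers.Dense(64, activation="relu", name="fc2")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="fc3")(x)
    return keras.Model(inputs, outputs, name="demo_network")


base = build_demo_network()

# Freeze all weights of the pre-trained network
for layer in base.layers:
    layer.trainable = False

# Drop the last layer (fc3) and attach a new, randomly initialized classifier
features = base.get_layer("fc2").output
outputs = layers.Dense(NUM_NEW_CLASSES, activation="softmax", name="new_fc")(features)
model = keras.Model(base.input, outputs, name="case1_model")

model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])

# Only the new layer should remain trainable
print([layer.name for layer in model.layers if layer.trainable])
```

Because only `new_fc` is trainable, training updates a small number of parameters, which is what limits overfitting on a small dataset.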

2.2 Case 2: small dataset, different data

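If the new dataset is small and different from the original training data:

Remove most of the pre-trained layers, keeping only those near the beginning of the network

Add a new fully connected layer to the remaining pre-trained layers, matching the number of classes in the new dataset

Randomly initialize the weights of the new fully connected layer; freeze all weights from the pre-trained network

Train the network to update the weights of the new fully connected layer

Because the dataset is small, overfitting is still a concern. To counter it, the weights of the original network should be kept fixed, just as in the first case.

However, the original training set and the new dataset do not share the same higher-level features. In this case, the new network uses only the layers that contain lower-level features.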

This approach can be visualized as follows:

[Figure: the network adapted for case 2]
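A rough Keras sketch of case 2 is shown below. It again uses a hypothetical stand-in network with random weights in place of a real pre-trained model. Where exactly to cut the network (here, after the first convolution and pooling layer) and the added global average pooling layer are assumptions for illustration, not something the text prescribes.

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_NEW_CLASSES = 17  # assumption: number of classes in the new dataset


def build_demo_network(num_classes=1000, input_shape=(64, 64, 3)):
    """Hypothetical stand-in for a pre-trained network (3 conv + 3 FC layers)."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu", name="conv1")(inputs)
    x = layers.MaxPooling2D(name="pool1")(x)
    x = layers.Conv2D(32, 3, activation="relu", name="conv2")(x)
    x = layers.MaxPooling2D(name="pool2")(x)
    x = layers.Conv2D(64, 3, activation="relu", name="conv3")(x)
    x = layers.GlobalAveragePooling2D(name="gap")(x)
    x = layers.Dense(128, activation="relu", name="fc1")(x)
    x = layers.Dense(64, activation="relu", name="fc2")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="fc3")(x)
    return keras.Model(inputs, outputs, name="demo_network")


base = build_demo_network()

# Freeze the pre-trained weights
for layer in base.layers:
    layer.trainable = False

# Keep only the layers near the beginning (low-level features) ...
low_level_features = base.get_layer("pool1").output

# ... and attach a new classifier directly on top of them
x = layers.GlobalAveragePooling2D(name="new_gap")(low_level_features)
outputs = layers.Dense(NUM_NEW_CLASSES, activation="softmax", name="new_fc")(x)
model = keras.Model(base.input, outputs, name="case2_model")

model.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["accuracy"])
print([layer.name for layer in model.layers])
```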

2.3 Case 3: large dataset, similar data

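If the new dataset is large and similar to the original training data:

Remove the last fully connected layer and replace it with a layer matching the number of classes in the new dataset

Randomly initialize the weights of the new fully connected layer

Initialize the remaining weights with the pre-trained weights

Retrain the entire network

With a large dataset, overfitting is less of a concern, so you can retrain all of the weights.

Because the original training set and the new dataset share the same higher-level features, the entire network is used.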

This approach can be visualized as follows:

[Figure: the network adapted for case 3]
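Case 3 might look like this in Keras. The stand-in network and its layer names are hypothetical, and its weights are random here rather than truly pre-trained. The small learning rate is a common fine-tuning habit added as an assumption; the text itself does not specify an optimizer setting.

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_NEW_CLASSES = 17  # assumption: number of classes in the new dataset


def build_demo_network(num_classes=1000, input_shape=(64, 64, 3)):
    """Hypothetical stand-in for a pre-trained network (3 conv + 3 FC layers)."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu", name="conv1")(inputs)
    x = layers.MaxPooling2D(name="pool1")(x)
    x = layers.Conv2D(32, 3, activation="relu", name="conv2")(x)
    x = layers.MaxPooling2D(name="pool2")(x)
    x = layers.Conv2D(64, 3, activation="relu", name="conv3")(x)
    x = layers.GlobalAveragePooling2D(name="gap")(x)
    x = layers.Dense(128, activation="relu", name="fc1")(x)
    x = layers.Dense(64, activation="relu", name="fc2")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="fc3")(x)
    return keras.Model(inputs, outputs, name="demo_network")


base = build_demo_network()

# Replace the last fully connected layer with a randomly initialized one
features = base.get_layer("fc2").output
outputs = layers.Dense(NUM_NEW_CLASSES, activation="softmax", name="new_fc")(features)
model = keras.Model(base.input, outputs, name="case3_model")

# All layers stay trainable: the remaining layers start from the existing
# (pre-trained) weights and the whole network is retrained
for layer in model.layers:
    layer.trainable = True

model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              loss="categorical_crossentropy", metrics=["accuracy"])
print(len(model.trainable_weights), "trainable weight tensors")
```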

2.4 Case 4: large dataset, different data

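If the new dataset is large and different from the original training data:

Remove the last fully connected layer and replace it with a layer matching the number of classes in the new dataset

Retrain the network from scratch with randomly initialized weights

Alternatively, use the same strategy as in the "large dataset, similar data" case

Even though the new data differs from the original training data, initializing from the pre-trained weights may speed up training. With that choice, the procedure is the same as in the large, similar-data case.

If starting from the pre-trained network does not produce a successful model, the other option is to randomly initialize the convolutional network's weights and train it from scratch.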

This approach can be visualized as follows:

[Figure: the network adapted for case 4]
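The "train from scratch" option of case 4 simply means building the same architecture with freshly initialized weights and the new number of classes, then training every layer. A minimal sketch with a hypothetical demo network follows (layer names and sizes are assumptions). With a real Keras application model, passing `weights=None` together with `classes=17` to `ResNet50` should play the same role.

```python
from tensorflow import keras
from tensorflow.keras import layers

NUM_NEW_CLASSES = 17  # assumption: number of classes in the new dataset


def build_demo_network(num_classes=1000, input_shape=(64, 64, 3)):
    """Hypothetical demo architecture (3 conv + 3 FC layers), randomly initialized."""
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(16, 3, activation="relu", name="conv1")(inputs)
    x = layers.MaxPooling2D(name="pool1")(x)
    x = layers.Conv2D(32, 3, activation="relu", name="conv2")(x)
    x = layers.MaxPooling2D(name="pool2")(x)
    x = layers.Conv2D(64, 3, activation="relu", name="conv3")(x)
    x = layers.GlobalAveragePooling2D(name="gap")(x)
    x = layers.Dense(128, activation="relu", name="fc1")(x)
    x = layers.Dense(64, activation="relu", name="fc2")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="fc3")(x)
    return keras.Model(inputs, outputs, name="demo_network")


# Same architecture, new output size, no pre-trained weights loaded
scratch_model = build_demo_network(num_classes=NUM_NEW_CLASSES)

scratch_model.compile(optimizer="rmsprop", loss="categorical_crossentropy",
                      metrics=["accuracy"])
print(len(scratch_model.trainable_weights), "trainable weight tensors")
```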

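References

See this research paper, which systematically analyzes the transferability of features in pre-trained CNNs.

https://arxiv.org/pdf/1411.1792.pdf

Read the Nature paper describing Sebastian Thrun's cancer-detection CNN.

This is the first research paper to propose using GAP (global average pooling) layers for object localization.
http://cnnlocalization.csail.mit.edu/Zhou_Learning_Deep_Features_CVPR_2016_paper.pdf

See this repository, which uses a CNN for object localization.

https://github.com/alexisbcook/ResNetCAM-keras

Watch this video demonstration of object localization with a CNN (YouTube link; it may not be reachable from networks in mainland China).

https://www.youtube.com/watch?v=fZvOy0VXWAI

See this repository, which uses visualization to better understand bottleneck features.

https://github.com/alexisbcook/keras_transfer_cifar10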
