一个例子了解迁移学习

迁移学习

对于传统机器学习而言,要求训练样本与测试样本满足独立同分布,而且必须要有足够多的训练样本。而迁移学习能把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),目标领域往往只有少量有标签样本,使得目标领域能够取得更好的学习效果。

迁移方式

  • 样本迁移,在源领域中找出与目标领域相似的样本,增加该样本的权重,使其在预测目标与的比重加大。
  • 特征迁移,源领域与目标领域包含共同的交叉特征,通过特征变换将源领域和目标领域的的特征变换到相同空间,使它们具有相同分布。
  • 模型迁移,源领域和目标领域共享模型参数,将源领域已训练好的网络模型应用到目标领域的新问题上。
  • 关系迁移,源领域和目标领域具有某种相似关系,可以将源领域的逻辑关系应用到目标领域中。

模型迁移

这里基于预训练的卷积神经网络训练一组新参数,然后将其用于分类任务,这样就能共享模型参数,避免了从头开始训练模型的参数,大大减少训练时间。

数据集

在示例中使用flower17数据集,它是一个包含17种花卉类别的数据集,每个类别有80张图像。收集的花都是英国一些常见的花,这些图像具有大比例、不同姿态和光线变化等性质。

使用水仙花和款冬这两类花,并且在预训练的VGG16网络之上构建分类器。

实现

首先导入所有必需的库,包括应用程序、预处理、模型检查点以及相关对象,cv2库和NumPy库用于图像处理和数值的基本操作。

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np
复制代码

定义输入、数据源及与训练参数相关的所有变量。

img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1
复制代码

调用VGG16预训练模型,其中不包括顶部的平整化层。冻结不参与训练的层,这里我们冻结前五层,然后添加自定义层,从而创建最终的模型。

model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, output=predictions)
复制代码

接着开始编译模型,并为训练、测试数据集创建图像数据增强生成器。

model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
复制代码

生成增强后新的数据,根据情况保存模型。

train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')
复制代码

开始对模型中新的网络层进行拟合。

model_final.fit_generator(train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs,
                          validation_data=validation_generator, nb_val_samples=nb_validation_samples,
                          callbacks=[checkpoint, early])
复制代码

练完成后用水仙花图像测试这个新模型,输出的正确值应该为接近[1.,0.]的数组。

im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
im = np.expand_dims(im, axis=0).astype(np.float32)
im = preprocess_input(im)
out = model_final.predict(im)
print(out)
print(np.argmax(out))
复制代码
 1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
 2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
 3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
 4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
 5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
 6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567 
 7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
 8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
 9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5

 9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00[[0.2213877  0.77861226]]
复制代码

github

github.com/sea-boat/De…

-------------推荐阅读------------

我的开源项目汇总(机器&深度学习、NLP、网络IO、AIML、mysql协议、chatbot)

为什么写《Tomcat内核设计剖析》

我的2017文章汇总——机器学习篇

我的2017文章汇总——Java及中间件

我的2017文章汇总——深度学习篇

我的2017文章汇总——JDK源码篇

我的2017文章汇总——自然语言处理篇

我的2017文章汇总——Java并发篇


跟我交流,向我提问:

欢迎关注:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值