tensorflow 迁移学习_TensorFlow2学习十一、TF-Hub实现迁移学习

一、概念

1. TF-Hub介绍

Tensorflow-hub 是 google 提供的可以共享学习的打包函式库,帮开发者把TensorFlow的训练模型发布成模组,方便再次使用或是与社交共享。

2. 迁移学习

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。深度学习中在计算机视觉任务和自然语言处理任务中将预训练的模型作为新模型的起点是一种常用的方法,通常这些预训练的模型在开发神经网络的时候已经消耗了巨大的时间资源和计算资源,迁移学习可以将已习得的强大技能迁移到相关的的问题上。——百度百科

本文示例来自google tensorflow官网,主要演示以下3方面操作:

  1. tf.keras使用tensorflow hub
  2. 使用tf hub实现图片分类
  3. 进行简单的迁移学习

二、加载一个图像分类器

1. 导入包

from __future__ import absolute_import, division, print_function, unicode_literalsimport matplotlib.pylab as pltimport tensorflow as tf!pip install -q -U tf-hub-nightlyimport tensorflow_hub as hubfrom tensorflow.keras import layers

2. 一个图片网络分类器

下载分类器

classifier_url ="https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}IMAGE_SHAPE = (224, 224)classifier = tf.keras.Sequential([ hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))])
b9736392872a49148ef9d8522f8031c2

识别一个图片

import numpy as npimport PIL.Image as Imagegrace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)grace_hopper = np.array(grace_hopper)/255.0result = classifier.predict(grace_hopper[np.newaxis, ...])predicted_class = np.argmax(result[0], axis=-1)# 加标注labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')imagenet_labels = np.array(open(labels_path).read().splitlines())# 可视化plt.imshow(grace_hopper)plt.axis('off')predicted_class_name = imagenet_labels[predicted_class]_ = plt.title("Prediction: " + predicted_class_name.title())
dbc9904273794b70a7ee1fb514f6c2b1

三、简单的迁移学习

1. 直接预测花数据集

# 下载花数据集data_root = tf.keras.utils.get_file( 'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', untar=True)# 使用ImageDataGenerator's rescale 把数据转成tf hub需要的格式(值范围都在[0,1]之间)image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE)for image_batch, label_batch in image_data: print("Image batch shape: ", image_batch.shape) print("Label batch shape: ", label_batch.shape) break# 上面输出# Image batch shape: (32, 224, 224, 3)# Label batch shape: (32, 5)# 试试预测result_batch = classifier.predict(image_batch)predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]plt.figure(figsize=(10,9))plt.subplots_adjust(hspace=0.5)for n in range(30): plt.subplot(6,5,n+1) plt.imshow(image_batch[n]) plt.title(predicted_class_names[n]) plt.axis('off')_ = plt.suptitle("ImageNet predictions")
14a4fc922ce44a09bc28fd9d37debe0c


这个结果显示不够好。

2. 下载模型进行修改

feature_extractor_url = "https://hub.tensorflow.google.cn/google/tf2-preview/mobilenet_v2/feature_vector/2" #@param {type:"string"}feature_extractor_layer = hub.KerasLayer(feature_extractor_url, input_shape=(224,224,3))feature_batch = feature_extractor_layer(image_batch)print(feature_batch.shape) # (32, 1280)feature_extractor_layer.trainable = False# 重新创建模型model = tf.keras.Sequential([ feature_extractor_layer, layers.Dense(image_data.num_classes, activation='softmax')])# 预测器predictions = model(image_batch)# 编译模型model.compile( optimizer=tf.keras.optimizers.Adam(), loss='categorical_crossentropy', metrics=['acc'])# 训练class CollectBatchStats(tf.keras.callbacks.Callback): def __init__(self): self.batch_losses = [] self.batch_acc = [] def on_train_batch_end(self, batch, logs=None): self.batch_losses.append(logs['loss']) self.batch_acc.append(logs['acc']) self.model.reset_metrics()steps_per_epoch = np.ceil(image_data.samples/image_data.batch_size)batch_stats_callback = CollectBatchStats()history = model.fit_generator(image_data, epochs=2, steps_per_epoch=steps_per_epoch, callbacks = [batch_stats_callback])# 显示训练时损失值变化plt.figure()plt.ylabel("Loss")plt.xlabel("Training Steps")plt.ylim([0,2])plt.plot(batch_stats_callback.batch_losses)# 准确率变化情况plt.figure()plt.ylabel("Accuracy")plt.xlabel("Training Steps")plt.ylim([0,1])plt.plot(batch_stats_callback.batch_acc)
0b6ceb79102845d5a81e999e11fe7767
754bcd1d0c6e458383f056699f68b707
# 测试结果class_names = sorted(image_data.class_indices.items(), key=lambda pair:pair[1])class_names = np.array([key.title() for key, value in class_names])predicted_batch = model.predict(image_batch)predicted_id = np.argmax(predicted_batch, axis=-1)predicted_label_batch = class_names[predicted_id]label_id = np.argmax(label_batch, axis=-1)# 可视化plt.figure(figsize=(10,9))plt.subplots_adjust(hspace=0.5)for n in range(30): plt.subplot(6,5,n+1) plt.imshow(image_batch[n]) color = "green" if predicted_id[n] == label_id[n] else "red" plt.title(predicted_label_batch[n].title(), color=color) plt.axis('off')_ = plt.suptitle("Model predictions (green: correct, red: incorrect)")
f6f1bf290d4643e2bd0600d3341d335f

四、导出模型

import timet = time.time()export_path = "/tmp/saved_models/{}".format(int(t))model.save(export_path, save_format='tf')export_path# 装载reloaded = tf.keras.models.load_model(export_path)result_batch = model.predict(image_batch)reloaded_result_batch = reloaded.predict(image_batch)abs(reloaded_result_batch - result_batch).max()

模型可以被再次装载引用,或装成TFLite 、 TFjs格式。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值