tensorflow 迁移学习_基于Keras使用TensorFlow Hub实现迁移学习(TF2官方教程翻译)

最新版本:

http://www.mashangxue123.com/tensorflow/tf2-tutorials-images-hub_with_keras.html

英文版本:

https://tensorflow.google.cn/alpha/tutorials/images/hub_with_keras

TensorFlow Hub是一种共享预训练模型组件的方法。

TensorFlow Hub是一个用于促进机器学习模型的可重用部分的发布,探索和使用的库。特别是,它提供经过预先训练的TensorFlow模型,可以在新任务中重复使用。(可以理解为做迁移学习:可以使用较小的数据集训练模型,可以改善泛化和加快训练。)GitHub 地址:https://github.com/tensorflow/hub

有关预先训练模型的可搜索列表,请参阅TensorFlow模块中心TensorFlow Module Hub。

本教程演示:

  1. 如何在tf.keras中使用TensorFlow Hub。
  2. 如何使用TensorFlow Hub进行图像分类。
  3. 如何做简单的迁移学习。
6400ebc9982a482ab2f26b8eb7589a7a

1. 安装和导入包

安装命令:pip install -U tensorflow_hub

from __future__ import absolute_import, division, print_function, unicode_literalsimport matplotlib.pylab as pltimport tensorflow as tf import tensorflow_hub as hubfrom tensorflow.keras import layers

2. ImageNet分类器

2.1. 下载分类器

使用hub.module加载mobilenet,并使用tf.keras.layers.Lambda将其包装为keras层。

来自tfhub.dev的任何兼容tf2的图像分类器URL都可以在这里工作。

classifier_url ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}IMAGE_SHAPE = (224, 224)classifier = tf.keras.Sequential([ hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))])

2.2. 在单个图像上运行它

下载单个图像以试用该模型。

import numpy as npimport PIL.Image as Imagegrace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)grace_hopper = np.array(grace_hopper)/255.0grace_hopper.shape

(224, 224, 3)

添加批量维度,并将图像传递给模型。

result = classifier.predict(grace_hopper[np.newaxis, ...])result.shape

结果是1001元素向量的logits,对图像属于每个类的概率进行评级。因此,可以使用argmax找到排在最前的类别ID:

predicted_class = np.argmax(result[0], axis=-1)predicted_class653

2.3. 解码预测

我们有预测的类别ID,获取ImageNet标签,并解码预测

labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')imagenet_labels = np.array(open(labels_path).read().splitlines())plt.imshow(grace_hopper)plt.axis('off')predicted_class_name = imagenet_labels[predicted_class]_ = plt.title("Prediction: " + predicted_class_name.title())
2a1f45d562bb475aa5fb3108d1efb8a9

3. 简单的迁移学习

使用TF Hub可以很容易地重新训练模型的顶层以识别数据集中的类。

3.1. Dataset

对于此示例,您将使用TensorFlow鲜花数据集:

data_root = tf.keras.utils.get_file( 'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', untar=True)

将此数据加载到我们的模型中的最简单方法是使用 tf.keras.preprocessing.image.ImageDataGenerator,

所有TensorFlow Hub的图像模块都期望浮点输入在“[0,1]”范围内。使用ImageDataGenerator的rescale参数来实现这一目的。图像大小将在稍后处理。

image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE) Found 3670 images belonging to 5 classes.

结果对象是一个返回image_batch,label_batch对的迭代器。

for image_batch, label_batch in image_data: print("Image batch shape: 
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值