最新版本:
http://www.mashangxue123.com/tensorflow/tf2-tutorials-images-hub_with_keras.html
英文版本:
https://tensorflow.google.cn/alpha/tutorials/images/hub_with_keras
TensorFlow Hub是一种共享预训练模型组件的方法。
TensorFlow Hub是一个用于促进机器学习模型的可重用部分的发布,探索和使用的库。特别是,它提供经过预先训练的TensorFlow模型,可以在新任务中重复使用。(可以理解为做迁移学习:可以使用较小的数据集训练模型,可以改善泛化和加快训练。)GitHub 地址:https://github.com/tensorflow/hub
有关预先训练模型的可搜索列表,请参阅TensorFlow模块中心TensorFlow Module Hub。
本教程演示:
- 如何在tf.keras中使用TensorFlow Hub。
- 如何使用TensorFlow Hub进行图像分类。
- 如何做简单的迁移学习。
1. 安装和导入包
安装命令:pip install -U tensorflow_hub
from __future__ import absolute_import, division, print_function, unicode_literalsimport matplotlib.pylab as pltimport tensorflow as tf import tensorflow_hub as hubfrom tensorflow.keras import layers
2. ImageNet分类器
2.1. 下载分类器
使用hub.module加载mobilenet,并使用tf.keras.layers.Lambda将其包装为keras层。
来自tfhub.dev的任何兼容tf2的图像分类器URL都可以在这里工作。
classifier_url ="https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2" #@param {type:"string"}IMAGE_SHAPE = (224, 224)classifier = tf.keras.Sequential([ hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))])
2.2. 在单个图像上运行它
下载单个图像以试用该模型。
import numpy as npimport PIL.Image as Imagegrace_hopper = tf.keras.utils.get_file('image.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg')grace_hopper = Image.open(grace_hopper).resize(IMAGE_SHAPE)grace_hopper = np.array(grace_hopper)/255.0grace_hopper.shape
(224, 224, 3)
添加批量维度,并将图像传递给模型。
result = classifier.predict(grace_hopper[np.newaxis, ...])result.shape
结果是1001元素向量的logits,对图像属于每个类的概率进行评级。因此,可以使用argmax找到排在最前的类别ID:
predicted_class = np.argmax(result[0], axis=-1)predicted_class653
2.3. 解码预测
我们有预测的类别ID,获取ImageNet标签,并解码预测
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt','https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')imagenet_labels = np.array(open(labels_path).read().splitlines())plt.imshow(grace_hopper)plt.axis('off')predicted_class_name = imagenet_labels[predicted_class]_ = plt.title("Prediction: " + predicted_class_name.title())
3. 简单的迁移学习
使用TF Hub可以很容易地重新训练模型的顶层以识别数据集中的类。
3.1. Dataset
对于此示例,您将使用TensorFlow鲜花数据集:
data_root = tf.keras.utils.get_file( 'flower_photos','https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', untar=True)
将此数据加载到我们的模型中的最简单方法是使用 tf.keras.preprocessing.image.ImageDataGenerator,
所有TensorFlow Hub的图像模块都期望浮点输入在“[0,1]”范围内。使用ImageDataGenerator的rescale参数来实现这一目的。图像大小将在稍后处理。
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)image_data = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SHAPE) Found 3670 images belonging to 5 classes.
结果对象是一个返回image_batch,label_batch对的迭代器。
for image_batch, label_batch in image_data: print("Image batch shape: