最全tensorflow2.0学习路线 https://www. mashangxue123.com
在本教程中,您将学习如何使用预训练网络进行转移学习对猫与狗图像分类。主要内容:使用预训练的模型进行特征提取,微调与训练的模型。
预训练模型是一个保存的网路,以前在大型数据集上训练的,通常是在大规模图像分类任务上,您可以按原样使用预训练模型,也可以使用转移学习将此模型自定义为给定的任务。
转移学习背后的直觉是,如果一个模型在一个大而且足够通用的数据集上训练,这个模型将有效地作为视觉世界的通用模型。然后,您可以利用这些学习的特征映射,而无需从头开始训练大型数据集上的大型模型。
在本节中,您将尝试两种方法来自定义预训练模型:
- 特征提取:使用先前网络学习的表示从新样本中提取有意义的特征,您只需在与训练模型的基础上添加一个新的分类器(将从头开始训练),以便您可以重新调整先前为我们的数据集学习的特征映射。 您不需要(重新)训练整个模型,基本卷积网络已经包含了一些对图片分类非常有用的特性。然而,预训练模型的最后一个分类部分是特定于原始分类任务的,然后是特定于模型所训练的一组类。
- 微调:解冻冻结模型的顶层,并共同训练新添加的分类器和基础模型的最后一层,这允许我们“微调”基础模型中的高阶特征表示,以使它们与特定任务更相关。
你将要遵循一般的机器学习工作流程:
- 检查并理解数据
- 构建输入管道,在本例中使用Keras 的
ImageDataGenerator
- 构建模型
- 加载我们的预训练基础模型(和预训练的权重)
- 将我们的分类图层堆叠在顶部
- 训练模型
- 评估模型
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
keras = tf.keras
1. 数据预处理
1.1. 下载数据
使用 TensorFlow Datasets加载猫狗数据集。tfds
包是加载预定义数据的最简单方法,如果您有自己的数据,并且有兴趣使用TensorFlow进行导入,请参阅加载图像数据。
import tensorflow_datasets as tfds
tfds.load
方法下载并缓存数据,并返回tf.data.Dataset
对象,这些对象提供了强大、高效的方法来处理数据并将其传递到模型中。
由于"cats_vs_dog"
没有定义标准分割,因此使用subsplit功能将其分为训练80%、验证10%、测试10%的数据。
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)
(raw_train, raw_validation, raw_test), metadata = tfds.load(
'cats_vs_dogs', split=list(splits),
with_info=True, as_supervised=True)
生成的tf.data.Dataset
对象包含(图像,标签)对。图像具有可变形状和3个通道,标签是标量。
print(raw_train)
print(raw_validation)
print(raw_test)
<DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
<DatasetV1Adapter shapes: ((None, None, 3), ()), types: (tf.uint8, tf.int64)>
显示训练集中的前两个图像和标签:
get_label_name = metadata.features['label'].int2str
for image, label in raw_train.take(2):
plt.figure()
plt.imshow(image)
plt.title(get_label_name(label))
![83cf3434d59148a926fdc2bec35e0b67.png](https://i-blog.csdnimg.cn/blog_migrate/e0af8f6e08b753ca39b3af9b2c7d1359.png)
![16f60ffeae8d27af9d1a6eb1e5976a13.png](https://i-blog.csdnimg.cn/blog_migrate/5764e2178189e85fba4c99255a101114.jpeg)
1.2. 格式化数据
使用tf.image
模块格式化图像,将图像调整为固定的输入大小,并将输入通道重新调整为[-1,1]
范围。
IMG_SIZE = 160 # 所有图像将被调整为160x160
def format_example(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
使用map方法将此函数应用于数据集中的每一个项:
train = raw_train.map(format_example)
validation = raw_validation.map(format