这段时间在系统学习tensorflow的相关知识,恰好学习到了tensorflow的slim轻量级开发库。这个库的目的在于用尽量少的成本组织起来一套可以训练和测试自己的分类任务的代码,其中涉及到了迁移学习,所以我们分为下面几个步骤介绍:
什么是迁移学习;
什么是TF-Slim;
TF-Slim实现迁移学习的例程;
应用自己的数据集完成迁移学习。
什么是迁移学习:
一般在初始化CNN的卷积核时,使用的是正态随机初始化,此时训练这个网络的话就是在从头训练,然而既然反正都要初始化核参数,那么为什么不干脆拿一个在其他任务中训练好的参数进行初始化呢?一般认为如果一个网络在某个更为复杂的任务上表现优异的话(这需要大量的数据与长时间的训练),那么它的参数是具有比较好的特征抽取能力的,又因为CNN的前几层提取的一般为较低级的特征(边缘,轮廓等),所以这些参数即使换一个任务的话,也会有不错的效果(起码在前几层是这样,而且起码比正态随机初始化要好)。在一个数据量比较大的任务中完成训练的过程就是pre-train,用pre-train的参数初始化一个新的网络,并对这些参数再次训练(微调),使之适用于新任务的过程就是fine-tune。一般情况下,我们会选择ImageNet数据集上训练好的网络,因为它经过大数据量与长时间的训练。很多已有的imagenet图像数据覆盖面已经很广,基于此训练的网络权重已经能提取大致的目标特征(从低微像素到高维的结构特征),所以可使用fine-tune只训练框架中某些层的权重,当然根据自己数据集做全部权重重新训练的检测效果理论会更好些,需要权衡时间成本和检测精度的需求了;好在TensorFlow已经提供了各种pre-train model:
然后我们举个例子说下Google是怎么训练这些模型,在ImageNet数据集上,用128GB内存+8个NVIDIA Tesla K40 GPU训练Inception网络,耗时100个小时,Top1达到69.8%。
什么是TF-Slim
TF-slim是用于定义,训练和评估复杂模型的TensorFlow(tensorflow.contrib.slim)的新型轻量级高级API。 可以把它理解为TensorFlow提供的一种更高级的封装吧,其实它和迁移学习没什么关系,只是在后面的内容中会用到,所以在这里提一下。具体内容可以点击这里查看,其翻译版可点击这里。
上面这个链接目录下主要包含:
official models(这个是用Tensorflow高层API做的例子模型集,建议初学者可尝试);
research models(这个是很多研究者利用tensorflow做的模型集,这个不是官方提供的,是研究者个人在维护的);
samples folder (包含代码片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代码呈现);
而我说的slim工具包就在research文件夹下。它出现在models/research/slim/。
使用某网络结构进行全训练
1.从头训练
训练模型的代码被放在slim下的train_image_classfier.py文件里,这里用flower数据集来训练Inception_v3网络结构的深度神经网络模型。
对于windows环境
python train_image_classifier.py --train_dir=D:\\pythonWorkSpace\\use_slim\\logs\\in3flowers --dataset_name=flowers --dataset_split_name=train --dataset_dir=D:\\pythonWorkSpace\\use_slim\\flowers_tf_records --model_name=inception_v3
对于linux环境
python train_image_classifier.py --train_dir=/public/home/xxbai/PycharmProjects/use_slim/logs/in3flowers --dataset_name=flowers --dataset_split_name=train --dataset_dir=/public/home/xxbai/PycharmProjects/use_slim/flowers_tf_records --model_name=inception_v3
为了方便在不同的linux机器上运行,这里将工作路径抽取出来定义成一个变量。
WORK_DIR=/public/home/xxbai/PycharmProjects
python train_image_classifier.py --train_dir=${WORK_DIR}/use_slim/logs/in3flowers --dataset_name=flowers --dataset_split_name=train --dataset_dir=${WORK_DIR}/use_slim/flowers_tf_records --model_name=inception_v3
参数说明:
train_dir:要生成模型的路径。
dataset_name:数据集名字。
dataset_split_name:数据集的哪一部分是validation还是train。
dataset_dir:数据集路径。
model_name:模型名字。
2 预训练模型
预训练模型可以在前面所述的网址中下载对应网络的ImageNet模型,不过预训练时使用的样本必须与原来的输入尺寸和输出的分类个数一致。但是ImageNet模型上都是分成1000/1001类,显然一般不具有通用性,因此可以用微调方法解决。
对于预训练方法,只需要在从头训练的命令中添加一个参数 --checkpoint_path
--checkpoint_path=模型的路径
3 微调fine-tuning
预训练的模型都是在imagenet上实现的,最终的输出都是1000个类别,如果我们想使用在自己的数据集上面就需要微调了。
在微调的过程中,需要将原有模型的最后一层去掉,转换成自己的数据集对应的分类层。例如我们要训练flowers数据集,就需要将1000个输出换成10个输出。
具体做法:
1)通过参数--checkpoing_exclude_scopes 指定载入预训练模型时哪一层的权重不会被载入。
2)通过参数--trainable_scopes参数指定对哪一层的参数进行训练。当--trainable_scopes出现时,没有被指定训练的参数将在训练中被冻结。
对于linux环境
python train_image_classifier.py
--train_dir=/public/home/xxbai/PycharmProjects/use_slim/logs/in3flowers
--dataset_name=flowers
--dataset_split_name=train
--dataset_dir=/public/home/xxbai/PycharmProjects/use_slim/flowers_tf_records
--model_name=inception_v3
--checkpoint_path=/public/home/xxbai/PycharmProjects/use_slim/pretrained_models/inception_v3.ckpt
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--max_number_of_steps=1000
环境迁移运行代码
WOEK_DIR=/public/home/xxbai/PycharmProjects
python train_image_classifier.py
--train_dir=${WOEK_DIR}/use_slim/logs/in3flowers
--dataset_name=flowers
--dataset_split_name=train
--dataset_dir=${WOEK_DIR}/use_slim/flowers_tf_records
--model_name=inception_v3
--checkpoint_path=${WOEK_DIR}/use_slim/pretrained_models/inception_v3.ckpt
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
--max_number_of_steps=1000
windows按照对应的路径更改
4 模型评估
eval_image_classifier.py文件是已经封装好用来进行模型评估的,这里的评估结果是对从头训练做出的评估。
对于linux环境
python eval_image_classifier.py --checkpoint_path=/public/home/xxbai/PycharmProjects/use_slim/logs/in3flowers/model.ckpt-2438 --eval_dir=logs/in3flowers/model.ckpt-2438 --dataset_name=flowers --dataset_split_name=validation --dataset_dir=/public/home/xxbai/PycharmProjects/use_slim/flowers_tf_records --model_name=inception_v3
环境迁移运行代码
WOEK_DIR=/public/home/xxbai/PycharmProjects
python eval_image_classifier.py --checkpoint_path=${WOEK_DIR}/use_slim/logs/in3flowers/model.ckpt-2438 --eval_dir=logs/in3flowers/model.ckpt-2438 --dataset_name=flowers --dataset_split_name=validation --dataset_dir=${WOEK_DIR}/use_slim/flowers_tf_records --model_name=inception_v3
其中的2438代表模型迭代的次数。
可以看到重新训练在迭代2438次后的准确率只有0.1625.
TF-Slim实现迁移学习的例程
在TensorFlow的github网址中提供了一个包含了数据准备+训练+预测的例程—Flowers,它只需我们运行几个脚本或命令行,不需要该任何代码就可以,我们先把这个例程解释一下:
1.准备工作:
首先我们需要再https://github.com/tensorflow/models把TensorFlow-models下载下来,放在本地一个位置上,比如根目录。
2.转化TFRecord文件:
TFRecord文件是一种TensorFlow提供的数据格式,它可以将图片二进制数据和图片其他数据(如标签,尺寸等等)存储在同一个文件中,有种格式更加利于TensorFlow的读取机制。所以我们需要先生成Flowers数据集的TFRecord文件。
TensorFlow-models内提供了一个download_and_convert_data.py文件,我们可以利用这个代码完成数据准备工作,但是在此之前,建议把download_and_convert_flowers.py文件中的210行代码注释掉,这样一来解压缩之后的原始数据就可以留下来了,这样方便我们查看。
然后我们就可以运行这个文件了,注意一下我们要运行的是download_and_convert_data.py文件,要修改的是download_and_convert_flowers.py文件。因为我的系统是Windows,所以在这里我就直接使用命令行了,使用Linux的同学可以直接运行.sh文件,我们只需要进入slim后执行:
python download_and_convert_data.py --dataset_name=flowers --dataset_dir=/pythonWorkSpace/use_slim/flowers_tf_records
其中flowers_tf_records是文件夹的名字,代码将在该文件加内下载flowers数据集的压缩包,解压后生产TFRecord文件,压缩包大小大概有200多M的样子吧。
下载完成之后,代码会随机的抽取350张图片组成验证集,剩下的3320张组成训练集,并分别打成5个TFRecord文件。
再回到floewers_5文件夹中,我们就可以看到下面这些东西,一个压缩文件,一个解压缩之后的文件夹,10个TFRecord文件和一个labels文件。
3.迁移inception_v3训练新任务
数据集准备完成后,我们就可以进行训练,这里使用TF提供的inception_v3网络,首先我们需要在上面提到的那个图里下载下来inception_v3模型文件解压缩,我放在了D:\pythonWorkSpace\use_slim\pretrained_models下。
然后我们可以直接执行train_image_classifier.py文件:
在windows环境下运行:
python train_image_classifier.py
--dataset_name=flowers
--dataset_dir=D:\\pythonWorkSpace\\use_slim\\flowers_tf_records
--checkpoint_path=D:\\pythonWorkSpace\\use_slim\\pretrained_models\\inception_resnet_v2_2016_08_30.ckpt
--model_name=inception_resnet_v2
--checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits/logits
--trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits/logits
--train_dir=D:\\pythonWorkSpace\\use_slim\\use_slim\\logs\\in3flowers
--learning_rate=0.001
--learning_rate_decay_factor=0.76
--num_epochs_per_decay=50
--moving_average_decay=0.9999
--optimizer=adam
--ignore_missing_vars=True
--batch_size=32
在linux环境下运行:
WORK_DIR=/public/home/xxbai/PycharmProjects
python train_image_classifier.py --dataset_name=flowers --dataset_dir=${WORK_DIR}/use_slim/flowers_tf_records --checkpoint_path=${WORK_DIR}/use_slim/pretrained_models/inception_v3.ckpt --model_name=inception_v3 --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits --train_dir=${WORK_DIR}/use_slim/logs/in3flowers --learning_rate=0.001 --learning_rate_decay_factor=0.76 --num_epochs_per_decay=50 --moving_average_decay=0.9999 --optimizer=adam --ignore_missing_vars=True --batch_size=32 --max_number_of_steps=10000
运行结果:
4.准确率验证
短暂的训练之后,我们就可以测试下验证集上的准确率了,执行eval_image_classifier.py文件:
WORK_DIR=/public/home/xxbai/PycharmProjects
python eval_image_classifier.py --dataset_name=flowers --dataset_dir=${WORK_DIR}/use_slim/flowers_tf_records --dataset_split_name=validation --model_name=inception_v3 --checkpoint_path=${WORK_DIR}/use_slim/logs/in3flowers/model.ckpt-10000 --eval_dir=${WORK_DIR}/use_slim/logs/in3flowers/model.ckpt-10000 --batch_size=32
可以看到,一个5分类数据集经过10000次训练后,准确率有0.889,回召率是1。
TensorBoard
为了在训练期间损失和其他指标可视化,可以通过运行以下命令使用TensorBoard :
tensorboard --logdir=${TRAIN_DIR}
一旦TensorBoard开始运行,即可在浏览器中打开http://localhost:6006。
应用自己的数据集完成迁移学习
step1:定义新的datasets数据集文件
在slim/datasets/文件夹下 添加一个python文件,直接复制一份flowers.py,重命名为“humanFacial.py”(这个名字可根据你实际的数据集名字来更改,我用的是人脸表情数据集)
需要对赋值生成后的satellite.py内容做如下修改:
_FILE_PATTERN = 'flowers_%s_*.tfrecord'
更改为
_FILE_PATTERN = 'satellite_%s_*.tfrecord' #这个主要是根据你之前制作的tfrecord文件名来改的,我制作的训练文件为satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,验证文件为satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord。
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
更改为
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200} #这个根据自己训练和验证样本数量来改,我的训练数据是800张图/类,共6类,验证集时200张/类,共6类;
_NUM_CLASSES = 5
更改为
_NUM_CLASSES = 6 #实际训练类别为6类;
还需要对satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),这行代码做更改,由于用的数据集源文件都是XXXX.jpg格式,因此将默认的图像格式转为jpg,更改后为'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,对satellite.py文件完成制作与更改(其源码如下):
satellite.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from datasets import dataset_utils
slim = tf.contrib.slim
_FILE_PATTERN = 'satellite_%s_*.tfrecord'
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200}
_NUM_CLASSES = 6
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers.
Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type.
Returns:
A `Dataset` namedtuple.
Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name)
if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
# Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir)
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
step2:注册数据库
接下来对slim/datasets/dataset_factory.py文件做更改,注册下satellite数据库;修改之处如下(添加了两行红色字体代码):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'satellite': satellite,
}
step3:准备训练文件夹
在slim文件夹下新建如下目录文件夹,并将对应的文件放在相应目录下
slim/
satellite/
data/
satellite_train_00000-of-00002.tfrecord
satellite_train_00001-of-00002.tfrecord
satellite_validation_00000-of-00002.tfrecord
satellite_validation_00001-of-00002.tfrecord
label.txt
pretrained/
inception_v3.ckpt
train_dir/
data文件夹下存放你制作的tfrecord训练测试文件和标签名;
pretrained文件夹下存放官网训练的权重文件;下载地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz
train_dir文件夹下存放你训练得到的模型和日志;
step4-1:在现有模型结构上fine-tune
开始训练,在slim文件夹下,运行如下指令可开始训练(主要是训练逻辑层):
python train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
命令参数解析如下:
• --trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先来解 释参数trainable_scopes 的作用,因为非常重要。 trainable_scopes 规定了在模型中fine-tune变量的范围 。 这里的设定表示只对 InceptionV3/Logits, Inception V3/ AuxLogits 两个变量进行微调,其他变量都保持不动 。 Inception V3/Logits,Inception V3/ AuxLogits 就相当于在网络中的 fc8 ,它们是 Inception V3的“末端层” 。 如果不设定 trainable_scopes , 就会对模型中所有的参数进行训练。
• --train_dir=satellite/train_dir:表明会在 satellite/train_dir目录下保存日志和checkpoint。
• --dataset_name=satellite、 --dataset_split_ name=train: 指定训练的数据集。
• --dataset_dit=satellite/data:指定训练数据集保存的位置。
• --model_ name=inception _ v3 :使用的模型名称。
• --checkpoint_path=satellite/pretrained/inception_v3.ckpt:预训练模型的保存位置。
• --checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是 Inception V3 模型的末端层,对应着 ImageNet 数据集的 1000 类,和相当前的数据集不符,因此不要去恢复它。
• --max_number_of_steps 100000:最大的执行步数。
• --batch_size=32:每步使用的 batch 数量。
• --learning_rate=0.001 : 学习率。
• --learning_rate_decay_type=fixed:学习率是否自动下降,此处使用固定的学习率。
• --save_interval_secs=300:每隔 300s,程序会把当前模型保存到train_dir中。 此处就是目录 satellite/train_dir。
• --save_summaries_secs=2:每隔 2s,就会将日志写入到 train_dir 中。可以用 TensorBoard 查看该日志。此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔。
• --log_every_n_steps=10:每隔10步,就会在屏上打出训练信息。
• --optimizer=msprop:表示选定的优化器。
• --weight_decay=0.00004:选定的 weight_decay 值。 即模型中所高参数的 二次正则化超参数。
以上命令是只训练末端层 InceptionV3/Logits,Inception V3/ AuxLogits ,还 可以使用以下命令对所高层进行训练:
step4-2:训练整个模型权重数据
使用以下命令对所有层进行训练:
去掉 了--trainable_scopes 参数
python train_image_classifier.py \
--train_dir=satellite/train_dir \
--dataset_name=satellite \
--dataset_split_name=train \
--dataset_dir=satellite/data \
--model_name=inception_v3 \
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
--max_number_of_steps=100000 \
--batch_size=32 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--save_interval_secs=300 \
--save_summaries_secs=2 \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004
当train_image_classifier.py程序启动后,如果训练文件夹(即satellite/train_dir)里没再已经保存的模型,就会加载 checkpoint_path 中的预训练模型,紧接着,程序会把初始模型保存到 train_dir中 ,命名为 model.ckpt-0, 0 表示第 0 步。 这之后,每隔 5min (参数一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序还会把当前模型保存到同样的文件夹中 , 命名恪式和第一次保存的格式一样。 因为模型比较大,程序只会保留最新的 5 个模型。
此外,如果中断了程序并再次运行,程序会首先检查 train_dir 中有无已经保存的模型,如果有,就不会去加载 checkpoint_path 中的预训练模型, 而是直接加载 train_dir 中已经训练好的模型,并以此为起点进行训练。 Slim 之所以这样设计,是为了在微调网络的时候,可以方便地按阶段手动调整学习率等参数。
至此用slim工具包做fine-tune或重新训练的步骤就完成了。
导出模型
保存包含模型体系结构的GraphDef。
要使用由slim定义的模型名称,请运行:
$ python export_inference_graph.py \
--alsologtostderr \
--model_name=inception_v3 \
--output_file=/tmp/inception_v3_inf_graph.pb
$ python export_inference_graph.py \
--alsologtostderr \
--model_name=mobilenet_v1 \
--image_size=224 \
--output_file=/tmp/mobilenet_v1_224.pb
整合导出的Graph
如果然后要将结果模型与您自己的或预先训练的检查点一起用作mobile model,则可以运行freeze_graph以使用以下内容将变量内嵌为常量:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
输出节点名称将根据型号而有所不同,但您可以使用summarize_graph工具检查和估计它们。
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
参考: