本文主要讲解在现有常用模型基础上,如何微调模型,减少训练时间,同时保持模型检测精度。
首先介绍下Slim这个Google公布的图像分类工具包,可在github链接:modules and examples built with tensorflow 中找到slim包。
上面这个链接目录下主要包含:
official models(这个是用Tensorflow高层API做的例子模型集,建议初学者可尝试);
research models(这个是很多研究者利用tensorflow做的模型集,这个不是官方提供的,是研究者个人在维护的);
samples folder (包含代码片段和小的模型用以表述tensorflow特性,包含以博客形式存在的代码呈现);
而我说的slim工具包就在research文件夹下。
Slim库结构
不仅定义了很多接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型(包括Alexnet,CycleGAN,DCGAN,VGG16,VGG19,Inception V1~V4,ResNet 50, ResNet 101,MobileNet V1等)。
下面用slim工具包中的文件来对自己的数据集做训练,训练可分为利用已有的模型架构(如常见的VGG,Inception等的卷积,池化这些结构)来全新训练权重文件或是微调权重文件。由于很多已有的imagenet图像数据覆盖面已经很广,基于此训练的网络权重已经能提取大致的目标特征(从低微像素到高维的结构特征),所以可使用fine-tune只训练框架中某些层的权重,当然根据自己数据集做全部权重重新训练的检测效果理论会更好些,需要权衡时间成本和检测精度的需求了;
下面会依据成熟网络结构Incvption V3分别做权重文件的全部重新训练和部分重新训练(即fine-tune)来介绍;
(前提是你将slim工具库下载下来,安装了必要的tensorflow等框架;并且根据训练图像制作完成了tfrecord文件)
有关tfrecord训练文件的制作请参考:将图像制作成tfrecord
step1:定义新的datasets数据集文件
在slim/datasets/文件夹下 添加一个python文件,直接复制一份flowers.py,重命名为“satellite.py”(这个名字可根据你实际的数据集名字来更改,我用的是何大神的航拍图数据集)
需要对赋值生成后的satellite.py内容做如下修改:
_FILE_PATTERN = 'flowers_%s_*.tfrecord'
更改为
_FILE_PATTERN = 'satellite_%s_*.tfrecord' #这个主要是根据你之前制作的tfrecord文件名来改的,我制作的训练文件为satellite_train_00000-of-00002.tfrecord和satellite_train_00001-of-00002.tfrecord,验证文件为satellite_validation_00000-of-00002.tfrecord,satellite_validation_00001-of-00002.tfrecord。
SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
更改为
SPLITS_TO_SIZES = {'train': 4800, 'validation': 1200} #这个根据自己训练和验证样本数量来改,我的训练数据是800张图/类,共6类,验证集时200张/类,共6类;
_NUM_CLASSES = 5
更改为
_NUM_CLASSES = 6 #实际训练类别为6类;
还需要对satellite.py文件中的'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),这行代码做更改,由于用的数据集源文件都是XXXX.jpg格式,因此将默认的图像格式转为jpg,更改后为'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 至此,对satellite.py文件完成制作与更改(其源码如下):
satellite.py
-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
#
-
# Licensed under the Apache License, Version 2.0 (the "License");
-
# you may not use this file except in compliance with the License.
-
# You may obtain a copy of the License at
-
#
-
# http://www.apache.org/licenses/LICENSE-2.0
-
#
-
# Unless required by applicable law or agreed to in writing, software
-
# distributed under the License is distributed on an "AS IS" BASIS,
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
# See the License for the specific language governing permissions and
-
# limitations under the License.
-
# ==============================================================================
-
"""Provides data for the flowers dataset.
-
-
The dataset scripts used to create the dataset can be found at:
-
tensorflow/models/slim/datasets/download_and_convert_flowers.py
-
"""
-
-
from __future__
import absolute_import
-
from __future__
import division
-
from __future__
import print_function
-
-
import os
-
import tensorflow
as tf
-
-
from datasets
import dataset_utils
-
-
slim = tf.contrib.slim
-
-
_FILE_PATTERN =
'satellite_%s_*.tfrecord'
-
-
SPLITS_TO_SIZES = {
'train':
4800,
'validation':
1200}
-
-
_NUM_CLASSES =
6
-
-
_ITEMS_TO_DESCRIPTIONS = {
-
'image':
'A color image of varying size.',
-
'label':
'A single integer between 0 and 4',
-
}
-
-
-
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
-
"""Gets a dataset tuple with instructions for reading flowers.
-
-
Args:
-
split_name: A train/validation split name.
-
dataset_dir: The base directory of the dataset sources.
-
file_pattern: The file pattern to use when matching the dataset sources.
-
It is assumed that the pattern contains a '%s' string so that the split
-
name can be inserted.
-
reader: The TensorFlow reader type.
-
-
Returns:
-
A `Dataset` namedtuple.
-
-
Raises:
-
ValueError: if `split_name` is not a valid train/validation split.
-
"""
-
if split_name
not
in SPLITS_TO_SIZES:
-
raise ValueError(
'split name %s was not recognized.' % split_name)
-
-
if
not file_pattern:
-
file_pattern = _FILE_PATTERN
-
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
-
-
# Allowing None in the signature so that dataset_factory can use the default.
-
if reader
is
None:
-
reader = tf.TFRecordReader
-
-
keys_to_features = {
-
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=
''),
-
'image/format': tf.FixedLenFeature((), tf.string, default_value=
'jpg'),
-
'image/class/label': tf.FixedLenFeature(
-
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
-
}
-
-
items_to_handlers = {
-
'image': slim.tfexample_decoder.Image(),
-
'label': slim.tfexample_decoder.Tensor(
'image/class/label'),
-
}
-
-
decoder = slim.tfexample_decoder.TFExampleDecoder(
-
keys_to_features, items_to_handlers)
-
-
labels_to_names =
None
-
if dataset_utils.has_labels(dataset_dir):
-
labels_to_names = dataset_utils.read_label_file(dataset_dir)
-
-
return slim.dataset.Dataset(
-
data_sources=file_pattern,
-
reader=reader,
-
decoder=decoder,
-
num_samples=SPLITS_TO_SIZES[split_name],
-
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
-
num_classes=_NUM_CLASSES,
-
labels_to_names=labels_to_names)
step2:注册数据库
接下来对slim/datasets/dataset_factory.py文件做更改,注册下satellite数据库;修改之处如下(添加了两行红色字体代码):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datasets import cifar10
from datasets import flowers
from datasets import imagenet
from datasets import mnist
from datasets import satellite
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'satellite': satellite,
}
step3:准备训练文件夹
在slim文件夹下新建如下目录文件夹,并将对应的文件放在相应目录下
slim/
satellite/
data/
satellite_train_00000-of-00002.tfrecord
satellite_train_00001-of-00002.tfrecord
satellite_validation_00000-of-00002.tfrecord
satellite_validation_00001-of-00002.tfrecord
label.txt
pretrained/
inception_v3.ckpt
train_dir/
data文件夹下存放你制作的tfrecord训练测试文件和标签名;
pretrained文件夹下存放官网训练的权重文件;下载地址:http:/!download. tensorflow .org/models/inception _ v3_2016 _ 08 _ 28.tar.gz
train_dir文件夹下存放你训练得到的模型和日志;
step4-1:在现有模型结构上fine-tune
开始训练,在slim文件夹下,运行如下指令可开始训练(主要是训练逻辑层):
-
python train_image_classifier.py \
-
--train_dir=satellite/train_dir \
-
--dataset_name=satellite \
-
--dataset_split_name=train \
-
--dataset_dir=satellite/data \
-
--model_name=inception_v3 \
-
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
-
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
-
--trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
-
--max_number_of_steps=100000 \
-
--batch_size=32 \
-
--learning_rate=0.001 \
-
--learning_rate_decay_type=fixed \
-
--save_interval_secs=300 \
-
--save_summaries_secs=2 \
-
--log_every_n_steps=10 \
-
--optimizer=rmsprop \
-
--weight_decay=0.00004
命令参数解析如下:
• --trainable_ scopes=Inception V3/Logits,InceptionV3/ AuxLogits :首先来解 释参数trainable_scopes 的作用,因为非常重要。 trainable_scopes 规定了在模型中fine-tune变量的范围 。 这里的设定表示只对 InceptionV3/Logits, Inception V3/ AuxLogits 两个变量进行微调,其他变量都保持不动 。 Inception V3/Logits,Inception V3/ AuxLogits 就相当于在网络中的 fc8 ,它们是 Inception V3的“末端层” 。 如果不设定 trainable_scopes , 就会对模型中所有的参数进行训练。
• --train_dir=satellite/train_dir:表明会在 satellite/train_dir目录下保存日志和checkpoint。
• --dataset_name=satellite、 --dataset_split_ name=train: 指定训练的数据集。
• --dataset_dit=satellite/data:指定训练数据集保存的位置。
• --model_ name=inception _ v3 :使用的模型名称。
• --checkpoint_path=satellite/pretrained/inception_v3.ckpt:预训练模型的保存位置。
• --checkpoint_exclude_scopes=Inception V3/Logits,InceptionV3/ AuxLogits : 在恢复预训练模型时,不恢复这两层。正如之前所说,这两层是 Inception V3 模型的末端层,对应着 ImageNet 数据集的 1000 类,和相当前的数据集不符,因此不要去恢复它。
• --max_number_of_steps 100000:最大的执行步数。
• --batch_size=32:每步使用的 batch 数量。
• --learning_rate=0.001 : 学习率。
• --learning_rate_decay_type=fixed:学习率是否自动下降,此处使用固定的学习率。
• --save_interval_secs=300:每隔 300s,程序会把当前模型保存到train_dir中。 此处就是目录 satellite/train_dir。
• --save_summaries_secs=2:每隔 2s,就会将日志写入到 train_dir 中。可以用 TensorBoard 查看该日志。此处为了方便观察,设定的时间间隔较多,实际训练时,为了性能考虑,可以设定较长的时间间隔。
• --log_every_n_steps=10:每隔10步,就会在屏上打出训练信息。
• --optimizer=msprop:表示选定的优化器。
• --weight_decay=0.00004:选定的 weight_decay 值。 即模型中所高参数的 二次正则化超参数。
以上命令是只训练末端层 InceptionV3/Logits,Inception V3/ AuxLogits ,还 可以使用以下命令对所高层进行训练:
step4-2:训练整个模型权重数据
使用以下命令对所有层进行训练:
去掉 了--trainable_scopes 参数
-
python train_image_classifier.py \
-
--train_dir=satellite/train_dir \
-
--dataset_name=satellite \
-
--dataset_split_name=train \
-
--dataset_dir=satellite/data \
-
--model_name=inception_v3 \
-
--checkpoint_path=satellite/pretrained/inception_v3.ckpt \
-
--checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
-
--max_number_of_steps=100000 \
-
--batch_size=32 \
-
--learning_rate=0.001 \
-
--learning_rate_decay_type=fixed \
-
--save_interval_secs=300 \
-
--save_summaries_secs=2 \
-
--log_every_n_steps=10 \
-
--optimizer=rmsprop \
-
--weight_decay=0.00004
当train_image_classifier.py程序启动后,如果训练文件夹(即satellite/train_dir)里没再已经保存的模型,就会加载 checkpoint_path 中的预训练模型,紧接着,程序会把初始模型保存到 train_dir中 ,命名为 model.ckpt-0, 0 表示第 0 步。 这之后,每隔 5min (参数一save_interval_secs=300 指定了每隔 300s 保存一次,即 5min )。 程序还会把当前模型保存到同样的文件夹中 , 命名恪式和第一次保存的格式一样。 因为模型比较大,程序只会保留最新的 5 个模型。
此外,如果中断了程序并再次运行,程序会首先检查 train_dir 中有无已经保存的模型,如果有,就不会去加载 checkpoint_path 中的预训练模型, 而是直接加载 train_dir 中已经训练好的模型,并以此为起点进行训练。 Slim 之所以这样设计,是为了在微调网络的时候,可以方便地按阶段手动调整学习率等参数。
至此用slim工具包做fine-tune或重新训练的步骤就完成了。
相似文章参考:https://blog.csdn.net/chaipp0607/article/details/74139895