Tensorflow Slim微调模型

一、原理

在自己的数据集上训练一个新的深度学习模型时,一般采取在预训练ImageNet上进行微调的方法。什么是微调?这里以VGG16为例进行讲解。

VGG16的结构为卷积+全连接层。卷积层分为5个部分共13层,即conv1~conv5。还有三层全连接层,即fc6、fc7、fc8。卷积层加上全连接层合起来一共为16层。如果要将VGG16的结构用于一个新的数据集,首先要去掉fc8这一层。原因是fc8层的输入是fc7层的特征,输出是1000类的概率,这1000类正好对应了ImageNet模型中的1000个类别。在自己的数据中,类别数一般不是1000类,因此fc8层的结构在此时是不适用的。必须将fc8层去掉,重新采用符合数据集类别数的全连接层,作为新的fc8.比如数据集为5类,那么新的fc8的输出也应当是5类。

  此外,在训练的时候,网络的参数的初始值并不是随机化生成的,而是采用VGG16在ImageNet上已经训练好的参数作为训练的初始值。这样做的原因在于,在ImageNet数据集上训练过的VGG16的参数已经包含了大量有用的卷积过滤器,与其从零开始初始化VGG16的所有参数,不如使用已经训练好的参数当作训练的起点。这样做不仅可以节约大量训练时间,而且有助于分类起性能的提高。

     载入VGG16的参数后,就可以开始训练了。此时需要指定训练层数的范围。一般来说,可以选择以下几种范围进行训练:

  • 只训练fc8.训练范围一定要包含fc8这一层。之前说过,fc8的结构被调整过,因此它的参数不能直接从ImageNet预训练模型中取得。可以只训练fc8,保持其他层的参数不动。这就相当于将VGG16当作一个特征提取器,用fc7层提取的特征做一个softmax模型分类。这样做的好处是训练速度块,但往往性能不会太好。
  • 训练所有参数。还可以对网络中的所有参数进行训练,这种方法的训练速度可能比较慢,但是能取得较高的性能,可以充分发挥深度模型的威力。
  • 训练部分参数。通常是固定浅层参数不变,训练深层参数。如固定conv1、conv2部分的参数不训练,只训练conv3、conv4、conv5、fc6、fc7、fc8的参数。

二、数据集准备

将jpg格式样本集合转化为tfrecord格式。

三、使用Tensorflow Slim微调模型

slim是google公司公布的一个图像分类工具包,不仅定义了一些方便的接口,还提供了很多ImageNet数据集上常用的网络结构和预训练模型。包括VGG16\VGG19、Inception v1~v4、ResNet 50、ResNet101、MobileNet在内大多数常用模型的结构以及预训练模型,更多的模型会被持续添加进来。

1)下载Tensorflow Slim的源代码

git clone https://github.com/tensorflow/models.git

找到models/research/slim文件夹。

2)定义新的datasets文件

在slim/datasets中,定义了所有可以使用的数据库,为了使用我之前创建的tfrecord数据进行训练,必须要在datasets中定义新的数据库如handGesturePic。

首先在datasets/目录下新建一个文件handGesturePic.py,并将flowers.py文件中的内容复制到handGesturePic.py中。然后修改以下几处内容。

_FILE_PATTERN='handGesturePic_%s_*.tfrecord'//改成自己的图片的命名

SPLITS_TO_SIZES={‘train’:9488,'validation':2000}//训练集和测试集的总数目

_NUM_CLASSES=2

第二处修改:image/format部分

‘image/format’:tf.FixedLenFeature((),tf.string,default_value='jpg').//定义图片的默认格式。

修改完handGesturePic.py之后,还需要在同目录的data_factory.py文件中注册handGesturePic数据库。

添加以下内容:from datasets import handGesturePic

datasets_map={

’ cifarlO ’: cifarlO,

’ flowers ’: flowers,

’ image net ’: imagenet,

’ mnist ’: mnist,

‘handGesturePic’:handGesturePic,}

3)准备训练文件夹

在slim中新建 handGesturePic目录,在这个目录中进行以下操作:

新建一个data目录,将之前生成的5个转换好的训练数据复制进去(4个.tfrecord,1个label.txt)。

新建一个空的train_dir目录,用来保存训练过程中的日志和模型。

新建一个pretrained目录,在slim的GitHubi页面找到Inception-V3模型的下载地址http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz,下载并解压后,得到 inception_v3.ckpt文件,将该文件复制到pretrained目录下。

4)开始训练

在slim文件夹下,运行以下命令就可以开始训练了:

python train_image_calssifier.py –train_dir=handGesturePic/train_dir –dataset_name=handGesturePic –dataset_split_name=train –dataset_dir=handGesturePic/data –model_name=inception_v3 –checkpoint_path=handGesturePic/pretrained/inception_v3.ckpt –checkpoint_exclude_scopes=InceptionV3/Logits, InceptionV3/AuxLogits –trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits –max_number_of_steps=10000 –batch_size=32 –learning_rate=0.001 –learning _rate_decay_type=fixed –save_interval_secs=300 –save_summaries_secs=2 –log_every_n_steps=10 –optimizer=rmsprop –weight_decay=0.00004

参数解释:

trainable_scopes=InceptionV3/Logits, InceptionV3/AuxLogits。trainable_scopes规定了在模型中微调变量的范围。这里的设定表示只对 InceptionV3/Logits, InceptionV3/AuxLogits两个变量进行微调,其他变量都保持不动。 InceptionV3/Logits, InceptionV3/AuxLogits是inception V3的末端层。只对最后一层分类层进行训练,比如原来是1000类,现在训练的只是2类。如果不设定trainable_scopes,就只会对模型中所有的参数进行训练。

5)验证模型准确率

执行脚本:python eval_image_classifier.py –checkpoint_path=handGesturePic/train_dir –eval_dir=handGesturePic/eval_dir –dataset_name=handGesturePic –dataset_split_name=validation –dataset_dir=handGesturePic/data –model_name=inception_v3

修改eval_image_classifier.py中’ Accuracy': slim.metrics.streaming_accuracy(predicti。ns, labels),

’ Recall_S ’: slim.metrics.streaming_reca ll_at_k(

logits, labels, 5),//确定输出前几个的准确率,因为我只有2类,所以改为'1'

模型的准确率和召回率均为98%。

6)Tensorboard可视化

命令:tensorboard –logdir handGesturePic/train_dir

可以看到损失变化的曲线。当损失曲线比较平缓,收敛较慢时,可以考虑增大学习率,以加快收敛速度;如果损失曲线波动较大,无法收敛,就可能是学习率过大,此时就可以尝试适当减少学习率。

7)导出模型并对单张图片进行识别

首先在slim文件夹下运行:

python export_inference_graph.py –-alsologtostderr --model_name=inception_v3 –output_file=handGesturePic/inception_v3_inf_graph.pb –dataset_name handGesturePic

这个命令会在handGesturePic文件夹生成一个inception_v3_inf_graph.pb文件。该文件只保存了inception v3的网络结构,并不包含训练得到的模型参数。需要将checkpoint中的模型参数保存进来。

Python freeze_graph.py –input_graph handGesturePic/inception_v3_inf_graph.pb –input_checkpoint handGesturePic/train_dir/model.ckpt-5000 –input_binary true –output_node_names InceptionV3/Predictions/Reshape_1 –output_graph handGesturePic/frozen_graph.pb

如何使用导出的frozen_graph.pb来对单张图片进行预测?编写一个classify_image_inception_v3.py脚本来完成这件事:

Python classify_image_inception_v3.py –model_path handGesturePic/frozen_graph.pb –label_path data_prepare/Pic/label.txt –image_file test_image.jpg

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值