上一篇文章 TensorFlow 使用预训练模型 ResNet-50 介绍了使用 tf.contrib.slim 模块来简单导入 TensorFlow 预训练模型参数,进而使用 slim.learning.train 函数来 fine tuning 模型。这一篇文章,在预告的多任务多标签之前,再插入一篇简单的文章,延续 TensorFlow 导入预训练模型并精调神经网络参数的这个主题。这篇文章使用的方法和上一篇的明显不同,不过方法依旧非常简单,只需要使用类 tf.train.Saver 及其成员函数 .restore 即可。
模型定义、训练及预训练参数导入
首先,我们来梳理一下要使用预训练模型来微调神经网络需要做的事情:1.定义神经网络结构;2.导入预训练模型参数;3.读取数据进行训练;4.使用 Tensorboard 可视化训练过程(此处略,留到以后单独讲)。清楚了以上步骤之后,我们来看如下全部代码:
# -*- coding: utf-8 -*-
"""
Created on Tue May 8 13:58:54 2018
@author: shirhe-lyh
"""
import numpy as np
import os
import tensorflow as tf
from tensorflow.contrib.slim import nets
slim = tf.contrib.slim
def get_next_batch(batch_size=64, ...):
"""Get a batch set of training data.
Args:
batch_size: An integer representing the batch size.
...: Additional arguments.
Returns:
images: A 4-D numpy array with shape [batch_size, height, width,
num_channels] representing a batch of images.
labels: A 1-D numpy array with shape [batch_size] representing
the groundtruth labels of the corresponding images.
"""
... # Get images and the corresponding groundtruth labels.
return images, labels
if __name__ == '__main__':
# Specify which gpu to be used
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
batch_size = 64
num_classes = 5
num_step