本文承接上文 TensorFlow-slim 训练 CNN 分类模型(续),阐述通过 tf.contrib.slim 的函数 slim.learning.train 训练的模型,怎么通过人为的加入数据入口(即占位符)来克服无法用于图像推断的问题。要解决这个问题,最简单和最省时的方法是模仿。我们模仿的代码是 TensorFlow 实现的目标检测 API 中的文件 exporter.py,该文件的目的正是要将 TensorFlow-slim 训练的目标检测模型由 .ckpt 格式转化为.pb 格式,而且其代码中人为添加占位符的操作也正是我们需求的。坦白的说,我会用 TensorFlow 的 tf.contrib.slim 模块来构建和训练模型正是受 TensorFlow models 项目的影响,当时我需要训练目标检测器,因此变配置了 models 这个子项目,并且从头到尾的阅读了其中 object_detection 中的 Faster RCNN 的源代码,切实感受到了 slim 模块的简便和高效(学习 TensorFlow 最好的办法除了查阅文档之外,便是看 models 中各种项目的源代码)。
言归正传,现在我们回到主题,怎么加入占位符,将前一篇文章训练的 CNN 分类器用于图像分类。这个问题在我们知道通过模仿 exporter.py 就可以解决它的时候,就变得异常简单了。我们先来理顺一下解决这个问题的逻辑:
1.定义数据入口,即定义占位符 inputs = tf.placeholder(···);
2.将模型作用于占位符,得到数据出口,即分类结果;
3.将训练文件从 .ckpt 格式转化为 .pb 格式。
按照这个逻辑顺序,下面我们详细的来看一下自定义模型导出,即模型格式转化的代码(命名为 exporter.py,如果没有特别说明,exporter.py 指的都是我们修改 TensorFlow 目标检测中的 exporter.py 后的自定义文件):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 30 15:13:27 2018
@author: shirhe-lyh
"""
"""Functions to export inference graph.
Modified from: TensorFlow models/research/object_detection/export.py
"""
import logging
import os
import tempfile
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import saver as saver_lib
slim = tf.contrib.slim
# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when
# newer version of Tensorflow becomes more common.
def freeze_graph_with_def_protos(
input_graph_def,
input_saver_def,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
clear_devices,
initializer_nodes,
variable_names_blacklist=''):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if not saver_lib.checkpoint_exists(input_checkpoint):
raise ValueError(
"Input checkpoint ' + input_checkpoint + ' does not exist!")
if not output_node_names:
raise ValueError(
'You must supply the name of a node to --output_node_names.')
# Remove all the explicit device specifications for this node. This helps
# to make the graph more portable.
if clear_devices:
for node in input_graph_def.node:
node.device = ''
with tf.Graph().as_default():
tf.import_graph_def(input_graph_def, name='')
config = tf.ConfigProto(graph_options=tf.GraphOptions())
with session.Session(config=config) as sess:
if input_saver_def:
saver = saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
else:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(
input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ':0')
except KeyError:
# This tensor doesn't exist in the graph (for example
# it's 'global_step' or a similar housekeeping element)
# so skip it.
continue
var_list[key] = tensor
saver = saver_lib.Saver(var_list=var_list)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes)
variable_names_blacklist = (variable_names_blacklist.split(',') if
variable_names_blacklist else None)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
output_node_names.split(','),
variable_names_blacklist=variable_names_blacklist)
return output_graph_def
def replace_variable_values_with_moving_averages(graph,
current_checkpoint_file,
new_checkpoint_file):
"""Replaces variable values in the checkpoint with their moving averages.
If the current checkpoint has shadow variables maintaining moving averages
of the variables defined in the graph, this function generates a new
checkpoint where the variabl