tensorflow自定义op_TensorFlow 自定义模型导出:将 .ckpt 格式转化为 .pb 格式

本文承接上文 TensorFlow-slim 训练 CNN 分类模型(续),阐述通过 tf.contrib.slim 的函数 slim.learning.train 训练的模型,怎么通过人为的加入数据入口(即占位符)来克服无法用于图像推断的问题。要解决这个问题,最简单和最省时的方法是模仿。我们模仿的代码是 TensorFlow 实现的目标检测 API 中的文件 exporter.py,该文件的目的正是要将 TensorFlow-slim 训练的目标检测模型由 .ckpt 格式转化为.pb 格式,而且其代码中人为添加占位符的操作也正是我们需求的。坦白的说,我会用 TensorFlow 的 tf.contrib.slim 模块来构建和训练模型正是受 TensorFlow models 项目的影响,当时我需要训练目标检测器,因此变配置了 models 这个子项目,并且从头到尾的阅读了其中 object_detection 中的 Faster RCNN 的源代码,切实感受到了 slim 模块的简便和高效(学习 TensorFlow 最好的办法除了查阅文档之外,便是看 models 中各种项目的源代码)。

言归正传,现在我们回到主题,怎么加入占位符,将前一篇文章训练的 CNN 分类器用于图像分类。这个问题在我们知道通过模仿 exporter.py 就可以解决它的时候,就变得异常简单了。我们先来理顺一下解决这个问题的逻辑:

1.定义数据入口,即定义占位符 inputs = tf.placeholder(···);

2.将模型作用于占位符,得到数据出口,即分类结果;

3.将训练文件从 .ckpt 格式转化为 .pb 格式。

按照这个逻辑顺序,下面我们详细的来看一下自定义模型导出,即模型格式转化的代码(命名为 exporter.py,如果没有特别说明,exporter.py 指的都是我们修改 TensorFlow 目标检测中的 exporter.py 后的自定义文件):

#!/usr/bin/env python3

# -*- coding: utf-8 -*-

"""

Created on Fri Mar 30 15:13:27 2018

@author: shirhe-lyh

"""

"""Functions to export inference graph.

Modified from: TensorFlow models/research/object_detection/export.py

"""

import logging

import os

import tempfile

import tensorflow as tf

from tensorflow.core.protobuf import saver_pb2

from tensorflow.python import pywrap_tensorflow

from tensorflow.python.client import session

from tensorflow.python.framework import graph_util

from tensorflow.python.platform import gfile

from tensorflow.python.saved_model import signature_constants

from tensorflow.python.training import saver as saver_lib

slim = tf.contrib.slim

# TODO: Replace with freeze_graph.freeze_graph_with_def_protos when

# newer version of Tensorflow becomes more common.

def freeze_graph_with_def_protos(

input_graph_def,

input_saver_def,

input_checkpoint,

output_node_names,

restore_op_name,

filename_tensor_name,

clear_devices,

initializer_nodes,

variable_names_blacklist=''):

"""Converts all variables in a graph and checkpoint into constants."""

del restore_op_name, filename_tensor_name # Unused by updated loading code.

# 'input_checkpoint' may be a prefix if we're using Saver V2 format

if not saver_lib.checkpoint_exists(input_checkpoint):

raise ValueError(

"Input checkpoint ' + input_checkpoint + ' does not exist!")

if not output_node_names:

raise ValueError(

'You must supply the name of a node to --output_node_names.')

# Remove all the explicit device specifications for this node. This helps

# to make the graph more portable.

if clear_devices:

for node in input_graph_def.node:

node.device = ''

with tf.Graph().as_default():

tf.import_graph_def(input_graph_def, name='')

config = tf.ConfigProto(graph_options=tf.GraphOptions())

with session.Session(config=config) as sess:

if input_saver_def:

saver = saver_lib.Saver(saver_def=input_saver_def)

saver.restore(sess, input_checkpoint)

else:

var_list = {}

reader = pywrap_tensorflow.NewCheckpointReader(

input_checkpoint)

var_to_shape_map = reader.get_variable_to_shape_map()

for key in var_to_shape_map:

try:

tensor = sess.graph.get_tensor_by_name(key + ':0')

except KeyError:

# This tensor doesn't exist in the graph (for example

# it's 'global_step' or a similar housekeeping element)

# so skip it.

continue

var_list[key] = tensor

saver = saver_lib.Saver(var_list=var_list)

saver.restore(sess, input_checkpoint)

if initializer_nodes:

sess.run(initializer_nodes)

variable_names_blacklist = (variable_names_blacklist.split(',') if

variable_names_blacklist else None)

output_graph_def = graph_util.convert_variables_to_constants(

sess,

input_graph_def,

output_node_names.split(','),

variable_names_blacklist=variable_names_blacklist)

return output_graph_def

def replace_variable_values_with_moving_averages(graph,

current_checkpoint_file,

new_checkpoint_file):

"""Replaces variable values in the checkpoint with their moving averages.

If the current checkpoint has shadow variables maintaining moving averages

of the variables defined in the graph, this function generates a new

checkpoint where the variabl

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值