- Copy the IPython Notebook from here
- Change
    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))
  to
    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b), name="sample_prediction")
- Modify the code like so:
    with tf.Session(graph=graph) as session:
        tf.initialize_all_variables().run()
        print('Initialized')
        mean_loss = 0
        # code omitted (no changes)
        # new code below:
        saver = tf.train.Saver(tf.all_variables())
        saver.save(session, '/home/me/Documents/checkpoint.ckpt', write_meta_graph=False)
        tf.train.write_graph(graph.as_graph_def(), '/home/me/Documents', 'graph.pb')
- Run, and verify that `checkpoint.ckpt` and `graph.pb` have been created
- Run
    bazel build tensorflow/python/tools:freeze_graph && \
    bazel-bin/tensorflow/python/tools/freeze_graph \
      --input_graph=/home/me/Documents/graph.pb \
      --input_checkpoint=/home/me/Documents/checkpoint.ckpt \
      --output_graph=/home/me/Documents/frozen_graph.pb \
      --output_node_names=sample_prediction

    bazel build tensorflow/python/tools:freeze_graph && \
    bazel-bin/tensorflow/python/tools/freeze_graph \
      --input_graph=some_graph_def.pb \
      --input_checkpoint=model.ckpt-8361242 \
      --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
- Verify that `frozen_graph.pb` has been created
- Create a new IPython Notebook with the following code:
    from __future__ import print_function
    import os
    import numpy as np
    import random
    import string
    import tensorflow as tf
    from tensorflow.python.platform import gfile
    import zipfile
    from six.moves import range
    from six.moves.urllib.request import urlretrieve

    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with open('/home/me/Documents/frozen_graph.pb', "rb") as f:
            graph_def.ParseFromString(f.read())
        sample_prediction = tf.import_graph_def(graph_def, name="", return_elements=['sample_prediction:0'])
- Run
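
The file checks and the freeze command in the steps above could be scripted roughly as follows. This is only a sketch: `missing_files` and `freeze_graph_args` are hypothetical helper names, not part of TensorFlow, and the file names and flags are simply the ones used in the steps above.

```python
import os


def missing_files(directory, filenames):
    """Return the names in `filenames` that do not exist inside `directory`."""
    return [name for name in filenames
            if not os.path.isfile(os.path.join(directory, name))]


def freeze_graph_args(directory, output_node_names):
    """Build the freeze_graph argument list with the flags from the steps above.

    Assumes graph.pb and checkpoint.ckpt live in `directory` and that the
    frozen graph should be written next to them as frozen_graph.pb.
    """
    return [
        'bazel-bin/tensorflow/python/tools/freeze_graph',
        '--input_graph=' + os.path.join(directory, 'graph.pb'),
        '--input_checkpoint=' + os.path.join(directory, 'checkpoint.ckpt'),
        '--output_graph=' + os.path.join(directory, 'frozen_graph.pb'),
        '--output_node_names=' + output_node_names,
    ]
```

For example, `missing_files('/home/me/Documents', ['checkpoint.ckpt', 'graph.pb'])` should return an empty list before freezing, and the list from `freeze_graph_args('/home/me/Documents', 'sample_prediction')` could be handed to `subprocess.check_call` from the TensorFlow source root once the tool has been built.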
What have you tried?
- Originally, the graph also contained a node named `saved_sample_output`, and when I tried importing that frozen graph, the error complained about `saved_sample_output:0`. I tried removing the name, re-writing the checkpoint and graph files, re-freezing, and re-running the code. It then complained about `Variable_17:0`, which, after checking `graph.pb`, was what had originally been named `saved_sample_output`. Other than that, I haven't been able to find anything else out.
- Checked out #616 and looked at the solutions suggested for similar errors, but my `import_graph_def` never had an input map to begin with.
- Removing the name parameter, or the return_elements parameter, or both, hasn't made a difference.