所以我做了一些测试,发现了以下关于这个问题的信息。
由于我一直试图重用我创建的模型,我不得不使用tf.global_variables_初始值设定项(). 通过这样做,它覆盖了我导入的图形,所有的值都是随机的,这解释了不同的网络输出。这仍然给我留下了一个需要解决的问题:如何加载我的网络?目前我使用的解决方法还不是最佳的,但它至少允许我使用我保存的模型。张量流允许给函数和所使用的张量指定唯一的名称。通过这样做,我可以通过图表访问它们:with tf.Session() as sess:
saver = tf.train.import_meta_graph('path to .meta')
saver.restore(sess, tf.train.latest_checkpoint('path to checkpoints'))
graph = tf.get_default_graph()
graph.get_tensor_by_name('name:0')
使用这个方法,我可以访问所有保存的值,但它们是分开的!这意味着每次操作我都有1倍的权重和1倍的偏差,这导致了一堆新的变量。如果您不知道这些名称,请使用以下内容:
^{pr2}$
这将打印集合名称(变量存储在集合中)print(graph.get_collection('name'))
这允许我们访问集合,查看变量的名称/键是什么。在
这导致了另一个问题。我不能再使用我的模型,因为全局变量初始值设定项覆盖了所有内容。因此,我不得不用我之前得到的权重和偏差手动重新定义整个模型。在
不幸的是,这是我唯一能想到的。如果有人有更好的主意,请告诉我。在
整件事情看起来是这样的:imports...
placeholders for data...
def my_network(data):
## network definition with tf functions ##
return output
def train_my_net():
prediction = my_network(data)
cost function
optimizer
with tf.Session() as sess:
for i in how many epochs i want:
training routine
save
def use_my_net():
prediction = my_network(data)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('path to .meta')
saver.restore(sess, tf.train.latest_checkpoint('path to checkpoints'))
print(sess.run(prediction.eval(feed_dict={placeholder:data})))
graph = tf.get_default_graph()