tensorflow 中恢复模型遇到的坑

tensorflow 模型保存和恢复模型再训练,或者使用模型进行预测 

这一篇在我恢复模型的时候,第76行代码写的是

feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x_1:0"): input_x}

跟16,17行的不一样

with tf.name_scope('myPlaceholder') as scope:
    x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')

其实正确的应该是才对

feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x:0"): input_x}

原因在于一开始我写的是 myPlaceholder/x:0 ,但是报了个错误

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'myPlaceholder/x_1' with dtype double and shape [1,?]
     [[{{node myPlaceholder/x_1}} = Placeholder[dtype=DT_DOUBLE, shape=[1,?], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
     [[{{node b/_11}} = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_64_b", tensor_type=DT_DOUBLE, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

可以看到提示要  myPlaceholder/x_1 ,后面多了 "_1" ? 于是我以为是tensorflow在设计的时候是在保存模型的时候,默认添加一个索引之类的,就使用了 myPlaceholder/x_1:0 ,而不是正确的 myPlaceholder/x:0。

为什么会这样呢?既然错误中提示有 myPlaceholder/x_1 ,说明这个名字叫做 myPlaceholder/x_1 的节点是存在的,为什么会多出了这个节点?

原来我的  x = tf.placeholder(dtype=tf.float64, shape=(1, None), name='x')  是定义在全局的,也就是说,在恢复模型的时候,恢复的图结构中有 myPlaceholder/x:0 的存在,但是在全局又定义了一个名字叫 myPlaceholder/x:0 placeholder,于是产生了冲突,其中一个被自动修改成了 myPlaceholder/x_1 ,然后在下一次训练中,多出来的 placeholder 会被保存在图中,模型就有了多个placeholder。

解决办法就是把 x = tf.placeholder(dtype=tf.float64, shape=(1, None), name='x')   定义在第一次训练时要用的地方,而不是全局。

import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys

np.random.seed(1)
data_x = np.random.rand(100, 1)

np.random.seed(2)
data_y = np.random.rand(100, 1)

save_dir_path = 'model'
save_file_name = 'model.cpkt'


def model(input):
    return tf.add(tf.multiply(a, input), b, name="linear_model")


def process():
    for _ in tqdm(range(100)):
        _, value_a, value_b, value_loss = sess.run([train, a, b, loss_function], feed_dict=feed_dict_x)

    print('训练之后', value_a, value_b, 'loss', value_loss)

    saver.save(sess, save_path=os.path.join(save_dir_path, save_file_name))


if __name__ == '__main__':

    """
    如果不是gpu,将config去掉,使用默认的tf.Session()创建session
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)

    if not tf.train.checkpoint_exists(save_dir_path):
        os.mkdir(save_dir_path)

        with tf.name_scope('myPlaceholder') as scope:
            x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')
        a = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='a')
        b = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='b')

        saver = tf.train.Saver(max_to_keep=1)
        sess.run(tf.global_variables_initializer())

        model_output = model(x)
        loss_function = tf.reduce_mean(tf.square(model_output - data_y), name='loss')
        train = tf.train.GradientDescentOptimizer(learning_rate=0.002).minimize(loss_function, name='minimize')
        feed_dict_x = {x: data_x}
    else:
        saver = tf.train.import_meta_graph(os.path.join(os.getcwd(), save_dir_path, "model.cpkt.meta"))
        saver.restore(sess, os.path.join(save_dir_path, save_file_name))
        graph = tf.get_default_graph()

        a = graph.get_tensor_by_name('a:0')
        b = graph.get_tensor_by_name('b:0')
        print('恢复模型', sess.run(a), sess.run(b))

        loss_function = graph.get_tensor_by_name('loss:0')
        train = graph.get_operation_by_name('minimize')
        feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x:0"): data_x}

        op = input("""选择继续训练或者使用模型进行预测(1:训练   2:预测)""")

        if op is '1':
            pass

        elif op is '2':
            while True:
                input_x = np.array([[input("输入x:")]], dtype=np.float64)
                feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x:0"): input_x}
                output = sess.run(graph.get_tensor_by_name("linear_model:0"), feed_dict=feed_dict_x)
                print(output)
        else:
            sys.exit()

    process()
    sess.close()

 

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值