tensorflow 模型保存和恢复模型再训练,或者使用模型进行预测

初始目录结构

save.py代码

import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys

np.random.seed(1)
data_x = np.random.rand(100, 1)

np.random.seed(2)
data_y = np.random.rand(100, 1)

save_dir_path = 'model'
save_file_name = 'model.cpkt'

with tf.name_scope('myPlaceholder') as scope:
    x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')


def model():
    return tf.add(tf.multiply(a, x), b, name="linear_model")


def process():
    for _ in tqdm(range(1000)):
        _, value_a, value_b, value_loss = sess.run([train, a, b, loss_function], feed_dict=feed_dict_x)

    print('训练之后', value_a, value_b, 'loss', value_loss)

    saver.save(sess, save_path=os.path.join(save_dir_path, save_file_name))


if __name__ == '__main__':

    """
    如果不是gpu,将config去掉,使用默认的tf.Session()创建session
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)
    saver = None

    if not tf.train.checkpoint_exists(save_dir_path):
        os.mkdir(save_dir_path)

        a = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='a')
        b = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='b')
        saver = tf.train.Saver(max_to_keep=1)
        sess.run(tf.global_variables_initializer())

        model_output = model()
        loss_function = tf.reduce_mean(tf.square(model_output - data_y), name='loss')
        train = tf.train.GradientDescentOptimizer(learning_rate=0.002).minimize(loss_function, name='minimize')
        feed_dict_x = {x: data_x}
    else:
        saver = tf.train.import_meta_graph(os.path.join(os.getcwd(), save_dir_path, "model.cpkt.meta"))
        saver.restore(sess, os.path.join(save_dir_path, save_file_name))
        graph = tf.get_default_graph()

        a = graph.get_tensor_by_name('a:0')
        b = graph.get_tensor_by_name('b:0')
        print('恢复模型', sess.run(a), sess.run(b))
        loss_function = graph.get_tensor_by_name('loss:0')
        train = graph.get_operation_by_name('minimize')
        feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x_1:0"): data_x}

        op = input("""选择继续训练或者使用模型进行预测(1:训练   2:预测)""")

        if op is '1':
            pass

        elif op is '2':
            while True:
                input_x = np.array([[input("输入x:")]], dtype=np.float64)
                feed_dict_x = {graph.get_tensor_by_name("myPlaceholder/x_1:0"): input_x}
                output = sess.run(graph.get_tensor_by_name("linear_model:0"), feed_dict=feed_dict_x)
                print(output)
        else:
            sys.exit()

    process()
    sess.close()

上面的恢复模型是通过加载已经持久化的图,而下面的是通过已经定义图上的运算。区别在于使用上面的代码恢复模型时,即使是 注释掉了model()方法,依旧能正常运行,因为不依靠已经定义好的运算,下面的代码在恢复模型时,只把变量的值加载了进来,需要重复定义图上的运算。

import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import sys

np.random.seed(1)
data_x = np.random.rand(100, 1)

np.random.seed(2)
data_y = np.random.rand(100, 1)

save_dir_path = 'model'
save_file_name = 'model.cpkt'


def model():
    return tf.add(tf.multiply(a, x), b, name="linear_model")


def process():
    for _ in tqdm(range(1000)):
        _, value_a, value_b, value_loss = sess.run([train, a, b, loss_function], feed_dict=feed_dict_x)

    print('训练之后', value_a, value_b, 'loss', value_loss)

    saver.save(sess, save_path=os.path.join(save_dir_path, save_file_name))


if __name__ == '__main__':

    """
    如果不是gpu,将config去掉,使用默认的tf.Session()创建session
    """
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    with tf.name_scope('myPlaceholder') as scope:
        x = tf.placeholder(dtype=tf.float64, shape=(None, 1), name='x')
    a = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='a')
    b = tf.Variable(np.random.rand(1, 1), dtype=tf.float64, name='b')

    model_output = model()
    loss_function = tf.reduce_mean(tf.square(model_output - data_y), name='loss')
    train = tf.train.GradientDescentOptimizer(learning_rate=0.002).minimize(loss_function, name='minimize')
    feed_dict_x = {x: data_x}

    saver = tf.train.Saver(max_to_keep=1)

    if not tf.train.checkpoint_exists(save_dir_path):
        os.mkdir(save_dir_path)
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, os.path.join(save_dir_path, save_file_name))
        op = input("""选择训练或者使用模型进行预测(1:训练   2:预测)""")

        if op is '1':
            pass

        elif op is '2':
            while True:
                input_x = np.array([[input("输入x:")]], dtype=np.float64)
                feed_dict_x = {x: input_x}
                output = sess.run(model(), feed_dict=feed_dict_x)
                print(output)
        else:
            sys.exit()
    process()
    sess.close()

 

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值