tensorflow模型继续训练 fineturn


解决tensoflow如何在已训练模型上继续训练fineturn的问题。

 

训练代码


任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf


# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')

# 操作
result = tf.matmul(x, W) + b

# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))

# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

with tf.Session() as sess:
    # 初始化变量
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(max_to_keep=3)

    # 这里x、y给固定的值
    x_s = [[3.0]]
    y_s = [[100.0]]

    step = 0
    while (True):
        step += 1
        feed = {x: x_s, y: y_s}
        # 通过sess.run执行优化
        sess.run(train_step, feed_dict=feed)

        if step % 1000 == 0:
            print 'step: {0},  loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
            if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
                print ''
                # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
                print 'final result of {0} =  {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
                print ''
                print("模型保存的W值 : %f" % sess.run(W))
                print("模型保存的b : %f" % sess.run(b))
                break
    saver.save(sess, "./save_model/re-train", global_step=step)  # 保存模型

训练完成之后生成模型文件:

训练输出:

step: 1000,  loss: 4.89526428282e-08
step: 2000,  loss: 4.89526428282e-08
step: 3000,  loss: 4.89526428282e-08
step: 4000,  loss: 4.89526428282e-08
step: 5000,  loss: 4.89526428282e-08


final result of x×W+b =  [[99.99978]](目标值是100.0)

模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

 

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf


# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

with tf.Session() as sess:

    # 初始化变量
    sess.run(tf.global_variables_initializer())

    # saver = tf.train.Saver(max_to_keep=3)
    saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
    saver.restore(sess, tf.train.latest_checkpoint(r'./save_model'))  # 恢复数据

    # 从保存模型中恢复变量
    graph = tf.get_default_graph()
    W = graph.get_tensor_by_name("w:0")
    b = graph.get_tensor_by_name("b:0")

    print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
    print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))

    # 操作
    result = tf.matmul(x, W) + b
    # 损失函数
    lost = tf.reduce_sum(tf.pow((result - y), 2))
    # 优化
    train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

    # 这里x、y给固定的值
    x_s = [[3.0]]
    y_s = [[200.0]]

    step = 0
    while (True):
        step += 1
        feed = {x: x_s, y: y_s}
        # 通过sess.run执行优化
        sess.run(train_step, feed_dict=feed)
        if step % 1000 == 0:
            print 'step: {0},  loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
            if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
                print ''
                # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
                print 'final result of {0} =  {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
                print("模型保存的W值 : %f" % sess.run(W))
                print("模型保存的b : %f" % sess.run(b))
                break
    saver.save(sess, "./save_mode/re-train", global_step=step)  # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000,  loss: 1.95810571313e-07
step: 2000,  loss: 1.95810571313e-07
step: 3000,  loss: 1.95810571313e-07
step: 4000,  loss: 1.95810571313e-07
step: 5000,  loss: 1.95810571313e-07


final result of x×W+b =  [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958


从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。


总结


从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step)  # 保存模型


在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model'))  # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step)  # 保存模型

 

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

 

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值