【tensorflow】ckpt模型restore

tf版本:

项目目录内容如下

checkpoint中的内容如下

model_checkpoint_path: "mymodel-290"
all_model_checkpoint_paths: "mymodel-260"
all_model_checkpoint_paths: "mymodel-270"
all_model_checkpoint_paths: "mymodel-280"
all_model_checkpoint_paths: "mymodel-290"

model save与restore代码,save_test.py

# -*- coding: utf-8 -*-
import time
import os
import pickle
import random
import numpy as np
import tensorflow as tf
import sys

def train_model():
    # prepare the data
    x_data = np.random.rand(100).astype(np.float32)
    print(x_data)
    y_data = x_data * 0.1 + 0.2
    print(y_data)

    # define the weights
    W = tf.Variable(tf.random_uniform([1], -20.0, 20.0), dtype=tf.float32, name='w')
    b = tf.Variable(tf.random_uniform([1], -10.0, 10.0), dtype=tf.float32, name='b')
    y = W * x_data + b

    # define the loss
    loss = tf.reduce_mean(tf.square(y - y_data))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    # save model
    saver = tf.train.Saver(max_to_keep=4)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("------------------------------------------------------")
        print("before the train, the W is %6f, the b is %6f" % (sess.run(W), sess.run(b)))

        for epoch in range(300):
            if epoch % 10 == 0:
                print("------------------------------------------------------")
                print("after epoch %d, the loss is %6f" % (epoch, sess.run(loss)))
                print("the W is %f, the b is %f" % (sess.run(W), sess.run(b)))
                # 注意这里"model/mymodel",model是路径文件夹,mymodel是保存的model的名字
                saver.save(sess, "model/mymodel", global_step=epoch)
                print("save the model")
            sess.run(train_step)
        print("------------------------------------------------------")

#
def load_model():
    with tf.Session() as sess:
        # 注意这里"model/mymodel",model是路径文件夹,mymodel-290是保存的model的名字
        saver = tf.train.import_meta_graph('model/mymodel-290.meta')
        saver.restore(sess, tf.train.latest_checkpoint("model/"))
        print(sess.run('w:0'))
        print(sess.run('b:0'))

# 训练与保存model
train_model()

# 加载与恢复模型
load_model()

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值