tf版本:
项目目录内容如下
model 目录下 checkpoint 文件的内容如下:
model_checkpoint_path: "mymodel-290"
all_model_checkpoint_paths: "mymodel-260"
all_model_checkpoint_paths: "mymodel-270"
all_model_checkpoint_paths: "mymodel-280"
all_model_checkpoint_paths: "mymodel-290"
模型保存(save)与恢复(restore)的代码,文件名为 save_test.py:
# -*- coding: utf-8 -*-
import time
import os
import pickle
import random
import numpy as np
import tensorflow as tf
import sys
def train_model(checkpoint_prefix="model/mymodel", num_epochs=300):
    """Fit ``y = W * x + b`` to synthetic data and checkpoint the variables.

    The data is generated as ``y = 0.1 * x + 0.2``, so gradient descent drives
    W towards 0.1 and b towards 0.2. Every 10 epochs the loss, W and b are
    printed and a checkpoint ``<checkpoint_prefix>-<epoch>`` is written. The
    state after the final training step is also saved as
    ``<checkpoint_prefix>-<num_epochs>``. The Saver keeps only the 4 most
    recent checkpoints on disk.

    The graph is built in the current default graph.

    Args:
        checkpoint_prefix: path prefix for checkpoint files. In the default
            "model/mymodel", "model" is the directory and "mymodel" is the
            checkpoint file name.
        num_epochs: number of gradient-descent steps to run (default 300).
    """
    # prepare the data: 100 random x values and their exact linear targets
    x_data = np.random.rand(100).astype(np.float32)
    print(x_data)
    y_data = x_data * 0.1 + 0.2
    print(y_data)
    # define the weights; the names let a restored graph fetch 'w:0' / 'b:0'
    W = tf.Variable(tf.random_uniform([1], -20.0, 20.0), dtype=tf.float32, name='w')
    b = tf.Variable(tf.random_uniform([1], -10.0, 10.0), dtype=tf.float32, name='b')
    y = W * x_data + b
    # define the loss: mean squared error
    loss = tf.reduce_mean(tf.square(y - y_data))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    # save model: keep at most 4 checkpoints on disk
    saver = tf.train.Saver(max_to_keep=4)
    # TF1's Saver.save refuses to save when the parent directory is missing,
    # so create it up front.
    checkpoint_dir = os.path.dirname(checkpoint_prefix)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # W and b have shape [1]; index [0] to get scalars for %-formatting
        # (converting a 1-element array with %f is deprecated in NumPy).
        w_val, b_val = sess.run([W, b])
        print("------------------------------------------------------")
        print("before the train, the W is %6f, the b is %6f" % (w_val[0], b_val[0]))
        for epoch in range(num_epochs):
            if epoch % 10 == 0:
                loss_val, w_val, b_val = sess.run([loss, W, b])
                print("------------------------------------------------------")
                print("after epoch %d, the loss is %6f" % (epoch, loss_val))
                print("the W is %f, the b is %f" % (w_val[0], b_val[0]))
                # "<prefix>-<epoch>" holds the state after `epoch` training steps
                saver.save(sess, checkpoint_prefix, global_step=epoch)
                print("save the model")
            sess.run(train_step)
        # The periodic save above runs *before* the training step, so the steps
        # after the last periodic save would otherwise never be persisted.
        # Save the final trained state as well.
        loss_val, w_val, b_val = sess.run([loss, W, b])
        print("------------------------------------------------------")
        print("after epoch %d, the loss is %6f" % (num_epochs, loss_val))
        print("the W is %f, the b is %f" % (w_val[0], b_val[0]))
        saver.save(sess, checkpoint_prefix, global_step=num_epochs)
        print("save the model")
        print("------------------------------------------------------")
#
def load_model(checkpoint_dir="model/"):
    """Restore the most recent checkpoint in *checkpoint_dir* and print W and b.

    The graph structure is re-created from the checkpoint's ``.meta`` file, so
    the variables do not need to be re-declared in Python. The values of the
    tensors named 'w:0' and 'b:0' are then printed.

    Args:
        checkpoint_dir: directory holding the ``checkpoint`` index file and the
            checkpoint files (default ``"model/"``).

    Raises:
        FileNotFoundError: if no checkpoint is recorded in *checkpoint_dir*.
    """
    # Returns the checkpoint prefix (e.g. "model/mymodel-290"), or None.
    ckpt_prefix = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt_prefix is None:
        raise FileNotFoundError("no checkpoint found in %r" % checkpoint_dir)
    # Import into a fresh graph so the imported nodes cannot collide with nodes
    # already in the default graph (e.g. ones built earlier in this process).
    with tf.Graph().as_default(), tf.Session() as sess:
        # Take the graph definition from the same checkpoint being restored,
        # instead of hard-coding a specific step such as "mymodel-290.meta".
        saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
        saver.restore(sess, ckpt_prefix)
        print(sess.run('w:0'))
        print(sess.run('b:0'))
# Only train/restore when run as a script, not when this module is imported.
if __name__ == "__main__":
    # train and save the model
    train_model()
    # load and restore the model
    load_model()