TensorFlow 训练模型的保存和加载

我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。

  1. Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值
  2. 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  3. 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件

模型保存

下面是一个简单的保存训练模型的例子

import tensorflow as tf
import numpy as np

train_steps = 100 #表示训练的次数
checkpoint_steps = 50 #每50次训练保存一次模型checkpoints
checkpoint_dir = './model/'#表示模型文件的保存路径
learn_rate = 0.1

x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
train= tf.train.AdamOptimizer(learn_rate).minimize(loss)

saver = tf.train.Saver()  # defaults to saving all variables
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
init = tf.initialize_all_variables()

config = tf.ConfigProto(allow_soft_placement=True, allow_soft_placement=True)
config.gpu_options.per_process_gpu_memory_fraction = 0.4  #占用40%显存
sess = tf.Session(config=config)
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 使用 GPU 0,1
with tf.Session() as sess:
    sess.run(init)
    for i in xrange(train_steps):
        sess.run(train, feed_dict={x: x_data})
        if (i + 1) % checkpoint_steps == 0:
            saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)

说明:例子对于一个简单的网络,保存训练的结果,这是最常见的一种模型保存方式。

1. 记录设备指派情况 :  tf.ConfigProto(log_device_placement=True)

设置tf.ConfigProto()中参数log_device_placement = True ,可以获取到 operations 和 Tensor 被指派到哪个设备(几号CPU或几号GPU)上运行,会在终端打印出各项操作是在哪个设备上运行的。

2. 自动选择运行设备 : tf.ConfigProto(allow_soft_placement=True)

在tf中,通过命令 "with tf.device('/cpu:0'):",允许手动设置操作运行的设备。如果手动设置的设备不存在或者不可用,就会导致tf程序等待或异常,为了防止这种情况,可以设置tf.ConfigProto()中参数allow_soft_placement=True,允许tf自动选择一个存在并且可用的设备来运行操作。

3. 限制GPU资源使用:

为了加快运行效率,TensorFlow在初始化时会尝试分配所有可用的GPU显存资源给自己,这在多人使用的服务器上工作就会导致GPU占用,别人无法使用GPU工作的情况。tf提供了两种控制GPU资源使用的方法,一是让TensorFlow在运行过程中动态申请显存,需要多少就申请多少;第二种方式就是限制GPU的使用率。

模型加载

不需重新定义网络结构的方法: tf.train.import_meta_graph

import_meta_graph(
    meta_graph_or_file,
    clear_devices=False,
    import_scope=None,
    **kwargs
)
这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。

保存部分:

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
    sess.run(train_op)
    if step % 1000 == 0:
        # 保存checkpoint, 同时也默认导出一个meta_graph
        # graph名为'my-model-{global_step}.meta'.
        saver.save(sess, './model/model',global_step=step)

说明:

tf.add_to_collection:把变量放入一个集合,把很多变量变成一个列表

tf.get_collection:从一个结合中取出全部变量,是一个列表

tf.add_n:把一个列表的东西都依次加起来

载入部分:
checkpoint_dir = './model/'
with tf.Session() as sess: 
    graph = tf.get_default_graph()
    sess.run(tf.global_variables_initializer())

  session_conf = tf.ConfigProto(allow_safe_placement=True, log_device_placement =False)
  sess = tf.Session(config = session_conf)

  ckpt = tf.train.get_checkpoint_state(checkpoint_dir)  
  if ckpt and ckpt.model_checkpoint_path:  
      saver.restore(sess, ckpt.model_checkpoint_path)  
  else:  

  pass

  # tf.get_collection() 返回一个list. 但是这里只要第一个参数即
  y = tf.get_collection('pred_network')[0]
  # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。

  input_x = graph.get_operation_by_name('input_x').outputs[0]

  keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

  # 使用y进行预测  
  sess.run(y, feed_dict={input_x:....,  keep_prob:1.0})
因为计算图中保存了所有的operation,在载入模型后,可以恢复session并调用预测operation,新的数据可以喂到fee_dict中跑。但是必须保证新的数据和训练的数据有相同的格式和shape。

这里有两点需要注意的: 

一、 saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如 
model-10000.meta, model-10000.index, model-10000.data-00000-of-00001其中,*.meta文件保存了当前图结构,*.index文件保存了当前参数名,*.data文件保存了当前参数值。

import_meta_graph时填的就是meta文件名,我们知道权值都保存在model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、为了通过名字使用operation,你必须在原始模型中对operation和变量进行命名。在feed_dict中需要给出执行operation时需要的所有参数,所以如果你训练中使用了dropout,你这里也需要给出。模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

参考资料:

1.https://blog.csdn.net/laolu1573/article/details/70574544

2.https://blog.csdn.net/laolu1573/article/details/66971800

3.https://blog.csdn.net/lujiandong1/article/details/53301994

4.https://blog.csdn.net/dcrmg/article/details/79091941

5.https://blog.csdn.net/uestc_c2_403/article/details/72415791








评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值