Tensorflow学习笔记(四):常规模型保存与加载

1. TensorFlow中的模型

首先,我们先来理解一下TensorFlow里面的持久化模型是什么。从TensorFlow 0.11版本(含)起,我们持久化保存训练模型后,在模型保存目录中一般会出现下面四个文件:

IMAGE

  • .meta文件:保存了网络(模型)的计算图,包括所有的变量(variables)、操作(operations)、集合(collections)等信息

  • .data-00000-of-00001文件和.index文件:二进制文件,保存了网络中所有的权重(weights)、偏置(biases)、梯度(gradient)和其他模型相关的参数等变量数值。.data-00000-of-00001文件保存了当前所有的训练参数变量,.index文件保存了当前参数名。在TensorFlow0.11版本之前,两者是一个.ckpt文件,现在是两个文件了

  • checkpoint:文本文件,默认保存最新的5个模型文件列表,删除前面没用的计算图和二进制文件

注:tensorflow 0.11版本之前,持久化保存训练模型只包含三个文件:checkpoint,.meta文件,.ckpt文件。


2. 保存TensorFlow模型

在TensorFlow中,如果想保存一个模型的计算图和参数值,那么就需要用到tf.train.Saver()类,使用方式如下:

第一步,在Session外生成一个模型保存对象

saver = tf.train.Saver()

第二步,以当前环境Session为参数,保存模型到本地磁盘

saver.save(sess, "/path/to/model_save_dir/model_name")

以下给出了保存TensorFlow模型的一种示例。

import tensorflow as tf
import os

# 全局变量设置,定义模型网络相关参数
your code


# 定义模型保存路径与模型保存名称,例如"/userhome/model/model_$ID/","mnist_model"
MODEL_SAVE_DIR = /path/to/model_save_dir/
MODEL_NAME = model_name

if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_DIR)


# 定义模型网络架构和相关依赖
your code


# 在此声明之后的变量将不会被保存
saver = tf.train.Saver() 

# 训练模型,并保存到指定目录
with tf.Session() as sess:
    # 训练模型
    your code
    
    # 保存模型
    saver.save(sess, os.path.join(MODEL_SAVE_DIR, MODEL_NAME))

与保存模型相关的其他参数:

  • 在模型训练过程中,如果我们想在指定的迭代次数后保存模型(比如迭代1000次),我们需要使用参数global_step,在对应的步数之后调用saver.save方法:
with tf.Session() as sess:
    # 训练模型
    your code

    # 保存模型
    saver.save(sess, os.path.join(MODEL_SAVE_DIR, MODEL_NAME), global_step=1000)

这会在模型名字后面附加上‘-1000’,同时将创建下列文件:

在这里插入图片描述

  • 在模型训练过程中,如果我们想每迭代1000次就保存一次模型,我们需要使用参数global_step和一个对训练次数的条件判断,在对应的步数之后调用saver.save方法。其中TRAINING_STEPS是全部的训练次数;step是实时的训练次数计数。
with tf.Session() as sess:
    # 训练模型
    your code
    
    # 保存模型
    tf.global_variables_initializer().run()
    
    for i in range(TRAINING_STEPS):
        xs, ys = mnist.train.next_batch(BATCH_SIZE)
        _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
        if i % 1000 == 0:
            print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step)

在这里插入图片描述
在第一次保存模型时(第1000次迭代),TensorFlow会创建.meta文件。因为计算图结构不变,所以我们无需每次都重复创建(无需在第2000、3000…或其他整千次迭代).meta文件,我们仅需保存迭代后的参数值。因此,当我们保存过一次计算图后(手动把.meta文件复制到其他路径),就可以设置参数write_meta_graph为False,不用每次都保存图结构。

with tf.Session() as sess:
    # 训练模型
    your code
    
    # 保存模型
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step, write_meta_graph=False)
  • 如果我们没有在tf.train.Saver()中指定任何参数,它会默认保存所有变量。如果只想保存一部分变量,在创建tf.train.Saver实例时,指定要保存的variables/collections,把它放到一个列表或者根据名字映射变量字典。
import tensorflow as tf
import os

w1 = tf.Variable(tf.random_normal(shape[2], name='w1'))
w2 = tf.Variable(tf.random_normal(shape[5], name='w2'))

MODEL_SAVE_DIR = /path/to/model_save_dir/
MODEL_NAME = model_name

if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_DIR)

saver = tf.train.Saver([w1, w2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME))
  • 如果你想保存最近100个模型(可按需设置,tf.train.Saver默认保存最近的5个模型)或者每训练两个小时保存一次,需要在tf.train.Saver类中设置参数max_to_keep或keep_checkpoint_every_n_hours。
saver = tf.train.Saver(max_to_keep=100)

saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)

假设我们的模型设置训练10000步,每1000步保存一次模型,下图是模型保存的结果。

在这里插入图片描述


3. 加载TensorFlow模型

3.1 分步加载模型

TensorFlow模型的计算图和变量数据是分开保存的,加载模型时可以先加载计算图,再加载图中的参数,分两步实现:

  • 第一步,加载计算图

当然,你也可以重新创建计算图。在这里,我们通过模型保存的.meta文件,使用tf.train.import_meta_graph( )加载计算图。

saver = tf.train.import_meta_graph('/path/to/mnist_model-1000.meta')

tf.train.import_meta_graph( )会把.meta文件中保存的计算图加载到当前网络中,接着我们需要加载已训练的各参数值。

  • 第二步,加载模型参数

我们可以通过调用tf.train.Saver()类的restore方法来加载模型参数。restore方法会根据给出的模型保存路径,自动寻找.data-00000-of-00001文件和.index文件,并自动加载。

saver.restore(sess, tf.train.latest_checkpoint('/path/to/MODEL_SAVE_DIR'))

3.2 一次性全部加载计算图和参数

saver = tf.train.Saver()

with tf.Session() as sess:
     saver.restore(sess, '/path/to/model_save_dir/model_name')

或:

saver = tf.train.Saver()

with tf.Session() as sess:
     saver.restore(sess,tf.train.latest_checkpoint('/path/to/model_save_dir'))

前一种是从model保存路径下寻找model_name,自动加载模型文件;后一种是加载model保存路径下的全部相关文件。


3.3 更加安全的加载模型(推荐使用这种方式)

更加安全一点的加载方式,先判断模型文件是否存在,若存在再加载。

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('/path/to/model_save_dir')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

以下给出了加载TensorFlow模型的一种示例。

import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape[2], name='w1'))
w2 = tf.Variable(tf.random_normal(shape[5], name='w2'))

MODEL_SAVE_DIR = /path/to/model_save_dir/
MODEL_NAME = model_name

if not os.path.exists(MODEL_SAVE_PATH):
    os.makedirs(MODEL_SAVE_DIR)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/path/to/model-name-n.meta')
    saver.restore(sess, /path/to/model_name)
    print(sess.run('w1:0'))

此时,如w1、w2这些参数值就被加载到网络,并且可以进行访问。


4. 参考文章

https://blog.csdn.net/weixin_41108334/article/details/81565733

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值