tensorflow生成常量_TensorFlow入门教程(五):变量的创建、初始化、保存和加载

本文介绍了如何在TensorFlow中创建、初始化、保存和加载模型变量。通过tf.Variable创建变量,并使用tf.train.Saver进行保存和恢复。详细讲解了变量的初始化、自定义初始化以及如何选择性地保存和恢复模型的部分变量。
摘要由CSDN通过智能技术生成

本教程由深度学习中文社区(Studydl.com)持续发布与更新, 本系列其余教程地址见文章末尾.

在上期的文章中, 我们使用了卷积神经网络模型去识别mnist数据集, 进一步提高正确率, 达到了99.2%的准确率。虽然不是最高,但是还是比较让人满意。而当我们训练完一个神经网络后, 我们常常需要保存网络参数, 以供后来继续使用, 所以我们今天给大家介绍下变量的创建、初始化、保存和加载. 在后面的时间里我们会推出一系列的TensorFlow与PyTorch的入门教程, 希望大家多多转发与关注。

当训练模型时,用变量来存储和更新参数。变量包含张量 (Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析是被加载。

本文章主要描述以下两个TensorFlow类。

tf.Variable 类tf.train.Saver类

创建

当创建一个变量时,你将一个张量作为初始值传入构造函数Variable()。TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值。

注意,所有这些操作符都需要你指定张量的shape。那个形状自动成为变量的shape。变量的shape通常是固定的,但TensorFlow提供了高级的机制来重新调整其行列数。

# 创建两个variables.weights = tf.Variable(tf.random_normal([784, 200],__stddev=0.35), name="weights")biases = tf.Variable(tf.zeros([200]), name="biases")

调用tf.Variable() 添加一些操作(Op, operation)到graph:

一个Variable操作存放变量的值。一个初始化op将变量设置为初始值。这事实上是一个tf.assign操作.初始值的操作,例如示例中对biases变量的zeros操作也被加入了graph。

tf.Variable的返回值是Python的tf.Variable类的一个实例。

初始化

变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。

你或者可以从检查点文件中重新获取变量值,详见下文。

使用tf.initialize_all_variables()添加一个操作对变量做初始化。记得在完全构建好模型并加载之后再运行那个操作。

# 创建两个variables.weights = tf.Variable(tf.random_normal([784, 200],__stddev=0.35),name="weights")biases = tf.Variable(tf.zeros([200]), name="biases")# 添加一个初始化的操作init_op = tf.global_variables_initializer()# 然后,当加载会话时with tf.Session() as sess:__# 运行初始化操作__sess.run(init_op)__# 然后使用模型

由另一个变量初始化

你有时候会需要用另一个变量的初始化值给当前变量初始化。由于tf.global_variables_initializer()是并行地初始化所有变量,所以在有这种需求的情况下需要小心。

用其它变量的值初始化一个新的变量时,使用其它变量的initialized_value()属性。你可以直接把已初始化的值作为新变量的初始值,或者把它当做tensor计算得到一个值赋予新变量。

# 创建一个变量并赋予随机值weights = tf.Variable(tf.random_normal([784, 200],__stddev=0.35), name="weights")# 创建另外一个变量w2, 值和weights一样w2 = tf.Variable(weights.initialized_value(), name="w2")# 创建另外一个变量w_twice, 值是weights的两倍w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

自定义初始化

tf.global_variables_initializer() 函数便捷地添加一个op来初始化模型的所有变量。你也可以给它传入一组变量进行初始化。详情请见Variables文档,包括检查变量是否被初始化。

保存和加载

最简单的保存和恢复模型的方法是使用tf.train.Saver对象。构造器给graph的所有变量,或是定义在列表里的变量,添加save和restoreops。saver对象提供了方法来运行这些ops,定义检查点文件的读写路径。

检查点文件

变量存储在二进制文件里,主要包含从变量名到tensor值的映射关系。当你创建一个Saver对象时,你可以选择性地为检查点文件中的变量挑选变量名。默认情况下,将每个变量Variable.name属性的值。

保存变量

用tf.train.Saver()创建一个Saver来管理模型中的所有变量。

# 创建两个变量v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# 添加一个初始化的操作init_op = tf.global_variables_initializer()# 添加一个op去保存和恢复变量saver = tf.train.Saver()# 然后,加载模型,初始化变量,做一些其他工作后将变量保存到磁盘.with tf.Session() as sess:__sess.run(init_op)__# 用模型做一些工作__...__# 保存变量到磁盘.__save_path = saver.save(sess, "/tmp/model.ckpt")__print"Model saved in file: ", save_path

恢复变量

用同一个Saver对象来恢复变量。注意,当你从文件中恢复变量时,不需要事先对它们做初始化。

# 创建两个变量v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2") ...# 添加一个初始化的操作init_op = tf.global_variables_initializer()# 添加一个op去保存和恢复变量saver = tf.train.Saver()# 然后,加载模型, 从磁盘恢复变量, 然后用模型做一些其他工作with tf.Session() as sess:__# 从磁盘恢复变量.__saver.restore(sess, "/tmp/model.ckpt")__print"Model restored."__# 用模型做一些其他工作__...

选择存储和恢复哪些变量

如果你不给tf.train.Saver() 传入任何参数,那么saver将处理graph中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。

有时候在检查点文件中明确定义变量的名称很有用。举个例子,你也许已经训练得到了一个模型,其中有个变量命名为"weights",你想把它的值恢复到一个新的变量"params"中。

有时候仅保存和恢复模型的一部分变量很有用。再举个例子,你也许训练得到了一个5层神经网络,现在想训练一个6层的新模型,可以将之前5层模型的参数导入到新模型的前5层中。

你可以通过给tf.train.Saver()构造函数传入Python字典,很容易地定义需要保持的变量及对应名称:键对应使用的名称,值对应被管理的变量。

注意:

如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore() 函数被运行时,它的值才会发生改变。如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化op。详情请见tf.global_variables_initializer()。

# 创建两个变量v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# 添加一个只恢复和保存变量v2的操作, 并且保存名为my_v2saver = tf.train.Saver({"my_v2": v2})# 然后就像正常情况下一样使用saver对象去保存和恢复变量....

本系列教程往期文章地址:

TensorFlow入门教程(四): 使用CNN提高MNIST数据集识别正确率

TensorFlow入门教程(三): 使用softmax回归模型识别MNIST数据集

TensorFlow入门教程(二): 构建神经网络分类器,对鸢尾花进行分类

TensorFlow入门教程(一): 变量,图与会话的基本用法

从零开始搭建深度学习服务器:TensorFlow + PyTorch + Torch

287a038962c97beabe77b43577b72c25.png

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值