Tensorflow学习笔记:CNN篇(8)——Finetuning,模型更为细化的保存与恢复

Tensorflow学习笔记:CNN篇(8)——Finetuning,模型更为细化的保存与恢复


前序

— 对于模型的保存和恢复,前文已经做了介绍,然而读者可能已经注意到,在设定的保存文件夹中有着4个不同的文件类型:
这里写图片描述
可以得知,根据需要每个文件类型都有其不同的用处,但是仅仅知道这些还不够,对于Tensorflow工作人员来说,需要更进一步了解不同文件所处的作用。


存储文件的解读

— 在介绍存储文件之前,先对Saver类进行一下解释。在不同的会话中,当需要将数据在硬盘上进行保存时,就可以使用Saver类。这个Saver构造类允许你去控制3个元素:

  • 目标(The target):设定目标。在分布式架构的情况下,我们可以指定要计算哪个Tensorflow服务器或者“目标”
  • 图(The graph):设置保存的图。保存希望会话处理的图。对于初学者来说,这里有一个棘手的事情就是在Tensorflow中总有一个默认的图,并且所有的操作都是在这个图中首先进行,所以总是在“默认图范围”内。
  • 配置(The config):设置配置。可以使用ConfigProto参数来配置Tensorflow。

—Saver类可以处理图中元数据和变量数据的保存和恢复,而我们唯一需要做的是,告诉Saver类需要保存哪个图和哪些变量。在默认的情况下,Saver类能处理默认图中包含的所有变量。但是,我们也可以创建出很多的Saver类,去保存想要的任何子图。
—介绍完Saver类,对于模型存储来说,这里有4个文件类型,依次如下:

  • checkpoint:检查点文件,记录存储文件名称。
  • save_model.ckpt.data-00000-of-00001:等价于save_model.ckpt,权重存储文件
  • save_model.ckpt.index:存储权重目录
  • save_model.ckpt.meta:模型的全部图文件

在对模型进行保存和恢复时,Saver类将保存于图像关联的任何元数据,这意味着加载元检查点还将恢复与图相关联的所有空变量、操作和集合。


代码示例

现在抛开理论介绍而对模型进行恢复与处理。,由于Tensorflow将整体的“图”文件存储在meta后缀的文件中,而将权重存储在ckpt后缀的文件中,在其具体使用时,对于模型权重的注入则是根据相应的名称来进行,因此,如果需要对模型中不同的权重进行重新注入的话,那么第一步就是需要赋予不同的权重以名称。

with tf.variable_scope("var"):
        self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
        self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")

这里首先使用了tf.variable_scope对域进行了定义,之后在定义域内对输入变量进行赋值。最终形成的名称为:

var/a_val
Step 1: 重新定义的线性回归类

首先是对于线性回归类的定义,在前面已经说了,需要对不同的变量或者占位符以及不同的函数定义其在图中的名称,这里为了简便,只定义了变量和占位符的名称:

import tensorflow as tf

class LineRegModel:

    def __init__(self):
        with tf.variable_scope("var"):
            self.a_val = tf.Variable(tf.random_normal([1]),name="a_val")
            self.b_val = tf.Variable(tf.random_normal([1]),name="b_val")
        self.x_input = tf.placeholder(tf.float32,name="input_placeholder")
        self.y_label = tf.placeholder(tf.float32,name="result_placeholder")
        self.y_output = tf.add(tf.multiply(self.x_input, self.a_val), self.b_val,name="output")
        self.loss = tf.reduce_mean(tf.pow(self.y_output - self.y_label, 2))

    def get_saver(self):
        return tf.train.Saver()

    def get_op(self):
        return tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)

在程序中可以看到,这里对每个变量或占位符都设置了相应的名称,而对变量域又设置了对应的域名。

Step 2: 重新对模型进行训练
import tensorflow as tf
import numpy as np
import global_variable
import lineRegulation_model as model

train_x = np.random.rand(5)
train_y = 5 * train_x + 3.2   # y = 5 * x + 3
model = model.LineRegModel()

a_val = model.a_val
b_val = model.b_val

x_input = model.x_input
y_label = model.y_label

y_output = model.y_output

loss = model.loss
optimize = model.get_op()
saver = model.get_saver()

if __name__ == "__main__":
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    flag = True
    epoch = 0
    while flag:
        epoch += 1
        _ , loss_val = sess.run([optimize,loss],feed_dict={x_input:train_x,y_label:train_y})
        if loss_val < 1e-6:
            flag = False
    print(a_val.eval(sess) , "   ", b_val.eval(sess))
    print("-----------%d-----------"%epoch)
    print(a_val.op)
    saver.save(sess,global_variable.save_path)
    print("model save finished")
    sess.close()

这里写图片描述
可以看到,其中的节点名被定义为“var/a_val”,这是类中被定义是赋予的变量名称。

Step 3: 模型的恢复

对于模型的恢复来说,需要首先恢复模型的整个图文件,之后从图文件中读取相应的节点信息。

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

Saver方法先从图中获取了整个图的信息,之后根据节点名称将不同的变量或者占位符重新按名称赋值。

#读取placeholder和最终的输出结果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')#最终输出结果的tensor

而具体的权重恢复则需要在对话中完成。

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')

完整代码

import tensorflow as tf

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

#读取placeholder和最终的输出结果
graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

input_placeholder=graph.get_tensor_by_name('input_placeholder:0')
labels_placeholder=graph.get_tensor_by_name('result_placeholder:0')
y_output=graph.get_tensor_by_name('output:0')

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')
    result = sess.run(y_output, feed_dict={input_placeholder: [1]})
    print(result)
    print(sess.run(a_val))

读者可能注意到,在程序中采用通过名称获取对应的变量值的时候,冒号的右边有一个0符号,这是在Tensorflow的图运行中为了进行参数的复用而使用的标记类型,这里读者可以对其忽略而直接使用,程序运行的结果如下:
这里写图片描述

Step 4: 恢复模型的特定值

如果要对模型的特定值进行恢复,同样可以使用这个首先载入图文件之后使用权重对其赋值的办法。

import tensorflow as tf

saver = tf.train.import_meta_graph('./model/save_model.ckpt.meta')

graph = tf.get_default_graph()
a_val = graph.get_tensor_by_name('var/a_val:0')

y_output=graph.get_tensor_by_name('output:0')

with tf.Session() as sess:
    saver.restore(sess, './model/save_model.ckpt')
    print(sess.run(a_val))

可以看到这里只定义了变量a_val,并通过相应的名称将其重新获取。这种方法可以获取到模型中特定的变量或者节点的值,其最终结果如下:
这里写图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值