静态图谱

计算图谱

自动微分是深度学习建模工具必备的一个技能点。我们将损失函数看作是若干基本函数的 复合函数,这里基本函数包括加减乘除、sin、cos、指数函数、幂函数、对数函数等。我们使用链式法则进行复合函数的导数,比如df(g)=df|g * dg。如果将基本函数看作节点,我们可以将复合函数使用图进行表示,求导的链式法则可以方便的在这个图上进行。这就是计算图谱提出的背景。

静态图与动态图

tensorflow1.x 使用的是静态图,先创建计算图谱后执行计算。好处就是图结构固定不变,方便优化,运行效率高。缺点就是,当修改图结构时,需要重新运行整张图才能看到修改的结果,调试不方便。所以静态图适合部署应用。

tensorflow2.x和pytorch使用的是动态图,执行计算时创建计算图谱。好处就是调试方便,缺点是运行效率不如静态图高。所以动态图比较适合科研,可以迅速修改模型进行验证。

tf.Session()

使用tf1.x进行计算的步骤是:

1、 创建计算图谱

2、 创建会话

3、 在会话中初始化计算图谱中的参数

4、 在会话中执行run,获取目标节点的输出结果

指定session的计算图谱的方式可以是:

1、 tf.Session(graph=YourGraph)

2、



with YourGraph.as_default():

	with tf.Session() as sess:

       		sess.run()

不明确指定graph的话,session会使用默认的graph。获取默认graph的命令是 tf.get_default_graph()。

tf.Session().run()

签名如下:

run(

fetches, feed_dict=None, options=None, run_metadata=None

)

官方说明:

在这里插入图片描述

注意:我们需要在feed_dict中明确fetches中的目标节点依赖的全部输入,否则报错。

静态图保存

我们使用tf.train.Saver()保存静态图。

命令如下:

with self.graph.as_default():   
	with tf.Session() as sess:       
		saver=tf.train.Saver()
              	#blablabla
              	saver.save(sess,save_path)

注意save()要在sess的上下文中调用。

保存生成的静态图文件如下:

在这里插入图片描述

meta文件保存了静态图的网络结构,index和data保存参数。注意参数的保存形式是key-value形式。

静态图重载

需要重载的场景通常有:

1、 使用训练好的模型进行预测,即部署应用。此时我们需要重新加载包含模型参数的静态图。

2、 训练过程中,需要在之前训练的基础上继续训练模型,此时需要重新加载之前训练得到的静态图。

静态图的保存是序列化过程,将内存中的数据结构保存成字符串。而加载就是反序列化过程,将硬盘上的字符串转成内存中的数据结构。举个例子,假设静态图中的算子节点的类定义如下:

class Node:

       def __init__(self,name,val1,val2):

              self.name=name

              self.val1=val1

              self.val2=val2

我们实例化一个加法操作节点,Node(‘+’,5,6), 序列化后的结果可能就是’{‘name’:’+’,’val1’:’5’, ’val2’:6}’。

加载静态图的反序列化通过命令tf.train.import_meta_graph(‘file.meta’)完成,我们读取.meta文件在当前的graph上恢复静态图结构。saver.restore(sess,‘file’)可以将保存的静态图参数加载进来。完整命令如下:

with YourGraph.as_default():

   with tf.Session() as sess:

       saver=tf.train.import_meta_graph('saved_file.meta') 

       saver.restore(sess,'saved_file')

sess.run(fetches=,feed_dict=)

大家一定注意到了,fetches和feed_dict我都空着了,因为这里需要填入tensor或者op的内存地址,而我们不知道这个内存地址。我们需要一个之前保存的序列化的节点名称和反序列化恢复的节点的内存地址的映射关系。tf.Graph().get_tensor_by_name()和tf.Graph().get_operation_by_name()就是做这个事情的。

假设fetch的tensor的名称是“target_tensor”,它依赖的输入名称是“input_tensor”,这些名称均是硬盘上保存的序列化的静态图的tensor名称,这样上面的代码就可以修改为:

with YourGraph.as_default():
    with tf.Session() as sess:
	saver=tf.train.import_meta_graph('saved_file.meta') 
	saver.restore(sess,'saved_file')

	y=self.graph.get_tensor_by_name("target_tensor
")
	x=self.graph.get_tensor_by_name(“input_tensor”)
	sess.run(fetches=y, feed_dict={x: values})

下期预告:
结合具体的建模例子说明静态图谱的加载方法。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值