tensorflow在函数中用tf.Print输出中间值的方法

tensorflow由于其基于静态图的模式,导致写代码的时候很难调试,除了用官方的调试工具外,最直接的方法就是把中间结果输出出来查看,然而,直接用print函数只能输出tensor变量的形状,而不是数值,想要输出tensor的具体数值需要用tf.Print函数。网上有很多关于这个函数使用方法的说明,这里简要介绍:

Print(
    input_,
    data,
    message=None,
    first_n=None,
    summarize=None,
    name=None
	)

参数:

  • input_:通过这个操作的张量。 (流入的数据流)
  • data:计算 op 时要打印的张量列表。(用[ ]引起来的一串需要打印的东西,用逗号隔开)
  • message:一个字符串,错误消息的前缀。
  • first_n:只记录 first_n 次数。负数日志,这是默认的。
  • summarize:只打印每个张量的固定数目的条目。如果没有,则每个输入张量最多打印3个元素。
  • name:操作的名称(可选)

然而网上大部分资源都是介绍如何在主函数中先建立一个op,再开启一个Session执行sess.run(op)的方法,但是如果想要输出函数中的中间值而该值又未传回主函数呢?这种情况下无法在函数中开启一个新的Session,但是仍然可以用tf.Print建立op来实现。

import tensorflow as tf
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def test():
    a=tf.constant(0)
    for i in range(10):  
        a_print = tf.Print(a,['a_value: ',a])
        a=a_print+1
    return a
    
if __name__=='__main__':
    with tf.Session() as sess:
        sess.run(test())

运行结果:
在这里插入图片描述
a_print可以理解为在图中新增了一个节点,在后续代码中当有别的变量使用了a_print时(如上例a=a_print+1),就会有数据从a_print节点上流过,就会输出值,而究竟会输出几次值呢?这其实并不是看下文中a_print被使用了几次,而是看数据流要从该节点上流经几次,可以理解为a_print这个op被“定义”了几次。

def test():
    a=tf.constant(0)
    a_print = tf.Print(a,['a_value: ',a])
    for i in range(10):  
        a=a_print+1
    return a
    
if __name__=='__main__':
    with tf.Session() as sess:
        sess.run(test())

如果把test()函数改成这样,则运行结果为:
在这里插入图片描述
输出仅被执行了一次,因为a_print这个op只被定义了一次,虽然后面在循环里不断被a使用,但是数据只从它身上经过了一次,所以只会print一次,并且a_print的值永远为0,最终返回的a的值也为1。
再把代码改成下例:

def test():
    a=tf.constant(0)
    a_print = tf.Print(a,['a_value: ',a])
    for i in range(10):  
        a_print=a_print+1
    return a
    
if __name__=='__main__':
    with tf.Session() as sess:
        sess.run(test())

运行结果是什么也不会输出,因为a_print这个op没有和别的变量发生关系,它没有被别的变量使用,在图里为孤立的一个节点,没有数据流过,就不会被执行。
而如果改成这样

def test():
    a=tf.constant(0)
    a_print = tf.Print(a,['a_value: ',a])
    for i in range(10):  
        a_print=a_print+1
    return a_print
    
if __name__=='__main__':
    with tf.Session() as sess:
        sess.run(test())

运行结果
在这里插入图片描述
返回的a_print值为10也是正确的,因为a_print在下文被返回,所以有数据流流经,会被执行,而因为a_print的定义只执行一次,所以只会输出一次。

  • 18
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值