前言:tensorflow之前的版本由于是采用静态图,所以在调试的时候比较麻烦,前面一篇文章介绍了tensorflow自带的调试器,类似于python的pdb调试,本文专门讲两个函数,用来打印tensorflow中的变量,前一篇文章参考:tensorflow Debugger教程(一)——使用自带的tfdbg进行调试
一、tensorflow查看tensor的传统做法
比如下面的例子:
import tensorflow as tf
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
with tf.Session() as sess:
x_,y_,z_=sess.run([x,y,z])
print(f"x:{x_} ,y:{y_} ,z:{z_}")
'''
x:[2 3 4 5] ,y:[20 30 40 50] ,z:[22 33 44 55]
'''
二、使用tf.Print()函数来查看
如果要实现上面同样的功能,需要这么做
import tensorflow as tf
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
xyz=tf.Print(x,["x:",x,"y:",y,'z:', z],message='x+y=z: ',summarize=100)
with tf.Session() as sess:
sess.run(xyz)
'''
x+y=z: [x:][2 3 4 5][y:][20 30 40 50][z:][22 33 44 55]
'''
2.1 tf.Print()函数定义的查看
tf.Print(input, data, message=None, first_n=None, summarize=None, name=None)
'''
参数param:
input: 是一个tensor,需要打印的张量;
data:data要求是一个list,list的每一个元素是一个张量,里面包含要打印的内容;
message:是需要输出的错误信息;
first_n:指只记录前n次;
summarize:是对每个tensor只打印的条目数量,如果是None,对于每个输入tensor只打印3个元素
name:这个操作op的名字
返回值return:
返回一个 Tensor,和 input的形状一样
'''
需要注意的是,tf.Print()只是构建一个op,需要run之后才会打印。而且,上面的例子中似乎觉得第一个参数x是没有什么用的,那是因为我们没有用到它的返回值,下面看一下tf.Print()的返回值是什么:
import tensorflow as tf
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
xyz=tf.Print(x+100,["x:",x,"y:",y,'z:', z],message='x+y=z: ',summarize=100) # 第一个参数为 x+100
with tf.Session() as sess:
print(sess.run(xyz))
'''
x+y=z: [x:][2 3 4 5][y:][20 30 40 50][z:][22 33 44 55]
[102 103 104 105] # 这其实就是tf.Print()的返回值,它实际上就是对x进行某种操作后返回,当然也可以不进行任何操作,直接返回x
'''
注意事项:
tf.Print()函数已经逐渐被遗弃,不再推荐使用,现在在较新的tensorflow版本中使用该方法会得到如下提示:
Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. ......后面的省略
推荐我们使用tf.print()函数,这是下面的内容。
三、tf.print()函数的使用
前面不管是直接使用python的print函数,还是使用tf.Print都需要构建一个会话对象session,然后再会话里面运行,而tf.print()则不在需要了,但是需要用到eager_execution模式。
3.1 tf.print()函数的定义
def print_v2(*inputs, output_stream="stderr", name=None):
"""
Args:
*inputs: 可以是一个tensor或者是一个string.也可以同时输出多个tensor
output_stream: "stdout", "stderr", "log(info)", "log(warning)", "log(error)"`. Defaults to "stderr".
name: A name for the operation (optional).
"""
3.2 查看一个tensor
import sys
import tensorflow as tf
# tf.compat.v1.enable_eager_execution() # 此处使用的是tensorflow 1.13.1 版本
tf.enable_eager_execution() # 此处使用的是tensorflow 1.12.0 版本
x=tf.constant([2,3,4,5])
tf.print(x, output_stream=sys.stderr)
'''
[2 3 4 5]
'''
3.2 同时查看多个tensor
import sys
import tensorflow as tf
# tf.compat.v1.enable_eager_execution() # 此处使用的是tensorflow 1.13.1 版本
tf.enable_eager_execution() # 此处使用的是tensorflow 1.12.0 版本
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
tf.print("x:",x, "y:",y,"z:",z,output_stream=sys.stderr)
'''
x: [2 3 4 5] y: [20 30 40 50] z: [22 33 44 55]
'''
3.3 在函数中使用tf.print()
import sys
import tensorflow as tf
# tf.compat.v1.enable_eager_execution() # 此处使用的是tensorflow 1.13.1 版本
tf.enable_eager_execution() # 此处使用的是tensorflow 1.12.0 版本
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
@tf.contrib.eager.defun # 构建一个函数
def add_xy(x,y):
z=tf.add(x,y)
tf.print("x:",x, "y:",y,"z:",z,output_stream=sys.stderr)
add_xy(x,y) # 直接调用函数,不再需要构建会话
'''
x: [2 3 4 5] y: [20 30 40 50] z: [22 33 44 55]
'''
3.4 在会话session里面通过run来打印tensor
import sys
import tensorflow as tf
# tf.compat.v1.enable_eager_execution() # 此处使用的是tensorflow 1.13.1 版本
#tf.enable_eager_execution() # 此处使用的是tensorflow 1.12.0 版本
# 由于此处演示在session中 使用,故而不再需要eager模式
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
print_op = tf.print("x:",x, "y:",y,"z:",z,output_stream=sys.stderr)
sess = tf.Session() # 构建一个会话
with sess.as_default():
sess.run(print_op)
'''
x: [2 3 4 5] y: [20 30 40 50] z: [22 33 44 55]
'''
注意:上面的做法显得有点鸡肋多余,因为这和我通过普通的print函数来查看tensor的值似乎没什么区别,的确如此,但是也提供了一种思路,那就是:
在网络的构建过程我们必须要构建session会话,能不能将会话session和tf.print()结合起来呢?参见下面
3.5 在构建graph的时候使用tf.print()
import sys
import tensorflow as tf
# tf.compat.v1.enable_eager_execution() # 此处使用的是tensorflow 1.13.1 版本
# tf.enable_eager_execution() # 此处使用的是tensorflow 1.12.0 版本
# 此处构建graph和session,不再使用eager模式
x=tf.constant([2,3,4,5])
y=tf.constant([20,30,40,50])
z=tf.add(x,y)
print_op = tf.print("x:",x, "y:",y,"z:",z,output_stream=sys.stderr)
sess = tf.Session() # 构建一个会话
with sess.as_default():
with tf.control_dependencies([print_op]):
tripled_tensor = z * 3
sess.run(tripled_tensor)
在上面我想要查看的是x,y,z这三个tensor,但是我不直接去打印它们,构建了一个操作tripled_tensor,它依赖于z,而z又依赖于x,y,所以在运行tripled_tensor的时候,会将x,y,z也运行,这个时候就可以间接得到tf.print()里面的tensor值,不需要显式的去run。