认识一个类主要从它的属性和方法两个方面去了解它的使用方法,属性提供了这个类中包含的信息,方法提供了这个类能做什么的途径。Tensor作为TF底层操作的一个基础元素,熟练掌握是学好TF的关键环节。
注:此部分内容核心思想来自于官方文档介绍 https://www.tensorflow.org/api_docs/python/tf/Tensor
Tensor是什么?
官方文档中说道:
A Tensor is a symbolic handle to one of the outputs of an Operation.
Tensor是某个OP输出的句柄,TF通过这种方式,就可以将图上的OP通过Tensor连接起来。对应到计算图上,OP是图上的顶点(或者说节点),Tensor为从一个顶点到另一个顶点的连线。Tensor通常是通过tensorflow.OP()
类似的函数创建得到的。
一个Tensor中大致描述了两件事:从哪来?自身有什么特征?
- 从哪里来:依赖关系
- 自身特性:具有的性质
Tensor的属性
查看官方文档可知,Tensor中一共有7个属性,按照从哪里来和自身有什么特征来分别认识它们。
A. 从哪里来?
- graph:这个Tensor所在的计算图
- op:产生这个Tensor的那个Operation
- value_index:来自于产生这个Tensor的Operation的第几个输出
- device:这个Tensor数据是来自于哪个设备计算的结果
B. 自身有什么特征?
- name:创建这个Tensor时候赋予的名字
- dtype:该Tensor中数据的类型。 Tensor的dtype来自于DType类:https://www.tensorflow.org/api_docs/python/tf/DType,这里可以看到TF中所有数据类型
- shape:该Tensor中数据存储的逻辑结构,即数据(矩阵)的维数
这里值得注意的是,Tensor由于描述的也是矩阵的特点,只是这个矩阵或向量在TF中被叫做张量。如果你熟悉numpy库的使用,应该对dtype和shape两个属性不陌生,其结构以及用法都是类似的。
关于Tensor的name,如果我们在用OP创建一个Tensor的时候不给它命名,则TF环境会默认给它起名,用上面的属性来表示自动命名方式,大致为:
<op_name>:<value_index>
其中
op_name
如果重名,则会自动在后面加“下划线+1,2,3,….”的编号形式往后排名字。如:Add、Add_1、Add_2……以此类推,因此除了一些特殊用法(如共享变量),一张计算图上不存在两个名称完全相同的op_name
Tensor的方法
纵观Tensor类,可以发现有两种类型的方法:
- 系统内定的方法,以及对这些方法的(override)重写方法——以双下划线开头和结尾的方法
- 常规的公共方法:consumers、eval、get_shape、set_shape
系统内定方法
Tensor类在数学上的用法,类似于numpy库的用法。在Tensor中,我们能够看到有大量的__xxx__(self, ...)
类型的方法。学过python的小伙伴都知道,这类命名方式被称作系统内定的名称,我们查看了源码后,发现在Tensor中这样名称的方法主要分为两类,一类是python内定的,如:__init__、__iter__
等方法;另外一类是在Tensorflow环境中经过重载的运算符。
a = tf.placeholder('int32')
b = tf.placeholder('int32')
c = a + b # 不能设定c的名字
# 等价于 c = tf.add(a, b) 可以设定c的名字,加上name='xxx'参数即可
# 也等价于 c = a.__add__(b) 可以设定c的名字,加上name='xxx'参数即可
有关python运算符重载相关的知识感兴趣的话,可自行搜索相关知识,在这里只需要了解通过运算符的重载,使得我们能够直接通过符号来操作Tensor对象,比如:在计算a+b的时候,python就会自动调用a的__add__
(左加)函数或者b的__radd__
(右加)函数,从而实现诸如 c = a + b
这样的操作。
通过查看TF的python接口源码:https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/framework/ops.py
可以找到如下定义:
# List of Python operators that we allow to override.
OVERLOADABLE_OPERATORS = {
# Binary. 二元的(有一个Tensor类型的对象作为参数)
"__add__",
"__radd__",
"__sub__",
"__rsub__",
"__mul__",
"__rmul__",
"__div__",
"__rdiv__",
"__truediv__",
"__rtruediv__",
"__floordiv__",
"__rfloordiv__",
"__mod__",
"__rmod__",
"__lt__",
"__le__",
"__gt__",
"__ge__",
"__and__",
"__rand__",
"__or__",
"__ror__",
"__xor__",
"__rxor__",
"__getitem__",
"__pow__",
"__rpow__",
# Unary. 一元的(不需要其他Tensor类型的对象作为参数)
"__invert__",
"__neg__",
"__abs__",
"__matmul__",
"__rmatmul__"
}
通过上面的这些方法,就可以方便的直接使用符号来创建默认名称的Tensor。
常规的公共方法
以后介绍函数时,格式均为:
[函数返回值] 函数名称(参数1,参数2,...)
1. 获取使用当前tensor的operation集合:consumers
[list of Operations] consumers()
在Tensor的属性中,可以通过查看op和value_index或者name来获得该tensor来自于哪个operation的哪个输出(从哪来),同理,也可以通过consumers方法知道这个Tensor被用在了哪些operations中当作输入(到哪去)。
例如,我们定义一个输入a被三个operation使用:
a = tf.placeholder('int32')
b = a + 1
c = a + 2
d = a * c
print(b)
print(c)
print(d)
print(a.consumers())
运行后得到输出结果:
Tensor("add:0", dtype=int32)
Tensor("add_1:0", dtype=int32)
Tensor("mul:0", dtype=int32)
[<tf.Operation 'add' type=Add>, <tf.Operation 'add_1' type=Add>, <tf.Operation 'mul' type=Mul>]
通过上面的小实验得出以下结论:
- 在Tensor中被重载的运算符在使用时的名字是系统自动设定的
- 在打印operation时,显示的是
<tf.Operation, op_name, op_type>
- a被三个operation当作了输入,分别为add、add_1和mul
2. 求取该tensor的结果:eval
[numpy array<ndarray>] eval(feed_dict=None, session=None)
这种方式是session的另一种使用形式,用这种方式与使用
sess.run()
的形式结果是一样的,但是要注意在使用eval方法的时候,必须要有个默认的会话是开启的状态。
引用官方文档中的一句话:
Before invoking Tensor.eval(), its graph must have been launched in a session, and either a default session must be available, or session must be specified explicitly.
翻译:使用eval函数之前,这个tensor所在的图必须已经被某个session启动并被设定为默认会话,或者明确指定某个session
也就是说,能够用两种方式去执行:
# 方法1:使用sess.as_default()方法指定默认会话
sess = tf.Session()
c = tf.constant(5.0)
with sess.as_default():
print(c.eval())
sess.close()
# 方法2:使用tf.Session()和with来产生一个会话的作用范围,这种方式会自动指定默认会话
c = tf.constant(5.0)
with tf.Session():
print(c.eval())
可以看到这两种方式都没有使用sess.run
这样的函数,但是我们仍然能够通过直接访问该变量的eval函数获取计算结果。上面的两种方式,对应直接使用sess的方式的代码如下:
# 方法1的对应方法:不用with的情况
sess = tf.Session()
c = tf.constant(5.0)
print(sess.run(c))
sess.close()
#方法2对应的方法:用with的情况
c = tf.constant(5.0)
with tf.Session() as sess: # 需要将其设定为sess
print(sess.run(c))
更具体的使用方法将在Session中再叙述,这里只需要知道Tensor.eval和Session.run用法类似就好,可以将eval方法看作是session.run方法的一个快捷调用的方式,可以使代码看起来简洁又直接。
3. get_shape 和 set_shape 方法
由于TF是将图的设计与计算分开来做的,当我们对数据的维数有要求的时候,就需要给它规定一个shape,这时候,就用到了set_shape的方法,这个方法主要作用是使得这个Tensor更加的specific。比如,在定义一个占位符的时候,我们可能想限定输入数据的维数,所以有如下代码:
a = tf.placeholder('float')
print(a.get_shape()) # 结果是: <unknown>
a.set_shape([2, 3])
print(a.get_shape()) # 结果是:(2, 3)
# 若在已经规定了维数后,再次规定维数,则会报错
a.set_shape([2, 4]) # 报错:4与3不匹配
print(a.get_shape())
因此,对于Variable以及其他已经在创建时规定好shape的,是无法通过这个函数修改维数的。想修改只能通过tf.reshape(...)
操作来重新生成一个新的Tensor。
Tensor的显示
通常在编程debug的时候,会希望输出显示一下某个变量的内容,来检查程序上的问题。当使用python中的print函数时,只能显示出来Tensor的摘要信息,而无法像numpy那样打印出Tensor中保存的数据。由于TF环境是将图的定义与计算分开的,所以这就使得实际运行时出现的问题(这个问题不一定会导致TF报错)的时候,使用开发环境的debug工具是无法通过断点调试的方式像平时debug普通python程序那样去watch每一个变量的变化情况的。除了使用TF官方提供的tfdbg工具来对运行过程进行调试以外,还可以使用tf.Print(<Tensor to pass>,list<Tensor to display>)
操作来对某个Tensor中的数据进行显示。
本篇主要是介绍Tensor类,下一篇文章中详细介绍这个操作的使用方法和使用体验。
小结
本文主要从Tensor类的属性和方法两个方面介绍了Tensor类的使用以及基本作用,Tensor作为TF的核心数据类型,正是由于它内部这些数据结构的定义,使得TF可以轻松的找到Tensor与Tensor之间的依赖关系,从而知道OP之间的关系,这样就为图的计算提供了结构依据。而于此同时,运算符的重写机制与Tensor对于同名情况自动命名的形式,也使得其使用感觉很自然,如果您熟悉numpy的操作,应该很快就能学会对Tensor的操作。