TensorFlow 中 Tensor 间的运算规则
- 相同 shape Tensor 间的所有算术运算都是 element-wise 的
- 不同 shape (但 dim 0 相同) Tensor 间的运算称为广播 (broadcasting), Tensor 与 Scalar 间的运算就是其中一种
- 在运算时要求各个 Tensor 的数据类型相同
一般的广播是很好理解的,但是有一种特殊情况乍一看会觉得是有问题的,其实又是可行的,就是当运算的两个 Tensor 中之一的最后一维为 1 的时候,例如下面这段代码:
import tensorflow as tf
x = tf.constant([[[1], [2], [3]],
[[4], [5], [6]],
[[7], [8], [9]]], dtype=tf.int64)
y = tf.constant([1, 2, 3, 4], dtype=tf.int64)
ans = x - y
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(ans))
x.shape = (3, 3, 1),y.shape = (4, ),按照一般的理解,ans 的运算是不能进行的,但是实际上是可以的,ans.shape = (3, 3, 4)
ans = [[[ 0 -1 -2 -3]
[ 1 0 -1 -2]
[ 2 1 0 -1]]
[[ 3 2 1 0]
[ 4 3 2 1]
[ 5 4 3 2]]
[[ 6 5 4 3]
[ 7 6 5 4]
[ 8 7 6 5]]]
结果是将 x 最后一维元素作为一个 Scalar 进行广播,这个 trick 有时还是挺有用的,可以在使用 tf.expand_dims(input, axis=-1) 对 input 进行增维后应用
再比如
import tensorflow as tf
x = tf.constant([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 0, 0, 0]], dtype=tf.int64)
weights = tf.constant([-1, -2, -3], dtype=tf.int64)
weights = tf.expand_dims(weights, axis=-1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
ans = sess.run(weights*x)
print(ans)
这里weights中的每个值代表一个作用在x对应行的权重,结果是:
[[ -1 -2 -3 -4]
[-10 -12 -14 -16]
[-27 0 0 0]]