需要一个直接判断两tensor的值是否完全相等的功能。目前(TensorFlow Core v2.8.0),TensorFlow 2里有 a==b或 tf.math.equal()这样element-wise的函数来判断相等,但并使用不方便。
因此写一个简单实现的函数 tensor_equal(a, b),若相等则返回True。
# 判断两个tensor的值是否相等
def tensor_equal(a, b):
# 判断类型是否均为tensor
if type(a) != type(b):
return False
if isinstance(a, type(tf.constant([]))) is not True:
if isinstance(a, type(tf.Variable([]))) is not True:
return False
# 判断形状相等
if a.shape != b.shape:
return False
# 逐值对比后若有False则不相等
if not tf.reduce_min(tf.cast(a == b, dtype=tf.int32)):
return False
return True
原理是element-wise比较后,取最小值若非False,则相等。
# 进行测试
if __name__ == '__main__':
x = tf.constant([0, 1])
ans = tensor_equal(x, tf.constant([0, 1]))
print(ans)
ans = tensor_equal(x, tf.constant([0, 1, 3]))
print(ans)
ans = tensor_equal(x, [0, 1])
print(ans)
ans = tensor_equal(x, tf.Variable([0, 1]))
print(ans)
返回结果:
True
False
False
False