判断两个tensor的值是否相等(TensorFlow 2)

        需要一个直接判断两tensor的值是否完全相等的功能。目前(TensorFlow Core v2.8.0),TensorFlow 2里有 a==b或 tf.math.equal()这样element-wise的函数来判断相等,但并使用不方便。

        因此写一个简单实现的函数 tensor_equal(a, b),若相等则返回True。

# 判断两个tensor的值是否相等
def tensor_equal(a, b):
    # 判断类型是否均为tensor
    if type(a) != type(b):
        return False
    if isinstance(a, type(tf.constant([]))) is not True:
        if isinstance(a, type(tf.Variable([]))) is not True:
            return False
    # 判断形状相等
    if a.shape != b.shape:
        return False
    # 逐值对比后若有False则不相等
    if not tf.reduce_min(tf.cast(a == b, dtype=tf.int32)):
        return False
    return True

        原理是element-wise比较后,取最小值若非False,则相等。

# 进行测试
if __name__ == '__main__':
    x = tf.constant([0, 1])
    ans = tensor_equal(x, tf.constant([0, 1]))
    print(ans)
    ans = tensor_equal(x, tf.constant([0, 1, 3]))
    print(ans)
    ans = tensor_equal(x, [0, 1])
    print(ans)
    ans = tensor_equal(x, tf.Variable([0, 1]))
    print(ans)

返回结果:

True
False
False
False

  • 3
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值