- 概述
Tensorflow中,test类用来做单元测试,它继承于unittest.TestCase类,里面包含了Tensorflow做单元测试相关的方法。
- Tensorflow之test类的相关类Benchmark和TestCase
- class Benchmark:Abstract class that provides helpers for TensorFlow benchmarks.
- class TestCase:Base class for tests that need to test TensorFlow.
- Tensorflow之test类的方法。
- 示例:测试tensor的值是否正确
这是一个基本的test类测试案例,基于Tensorflow2.3版本,代码有所修改。
import tensorflow as tf class SquareTest(tf.test.TestCase): def testSquare(self): with self.session(): #平方操作 x = tf.square([2, 3]) # 测试x的值是否等于[4,9] self.assertAllEqual(x.numpy(), [4, 9]) if __name__ == "__main__": tf.test.main()
程序从入口运行,tf.test.main()运行所有的单元测试,通过self.assertAllEqual(x.numpy(), [4, 9])判断x的值是否等于[4,9],运行结果如下:
......
Ran 2 tests in 0.059s
OK
如果把代码改为:
.....
# 测试x的值是否等于[3,9]
self.assertAllEqual(x.eval(), [3, 9])
.....
运行后,测试失败,结果如下:
AssertionError:
Arrays are not equal
(mismatch 50.0%)
x: array([4, 9], dtype=int32)
y: array([3, 9])
--------------------------------------------------------------------
Ran 2 tests in 0.028s
FAILED (failures=1)