TensorFlow中的tf.test.TestCase类继承了unittest.TestCase类,用于对tensorflow代码进行单元测试。
tf.test.TestCase 提供了 assertAllEqual 用于判断两个numpy array具有完全相同的值,session方法来运行计算图结点,以及其他方法,具体请看链接。
现在我们有如下的两个函数:
# Python3
import tensorflow as tf
def dense_layer(x, W, bias, activation=None):
y = x @ W + bias
if activation:
return activation(y)
else:
return y
def expand_reshape_tensor(x, high, width):
return tf.reshape(x, (high, width, 1, 1))
第一个函数就是一个全连接层,第二个函数用于对张量进行扩张塑形操作。
接下来我们创建UtilsTests类,继承tf.test.TestCase类,定义test_dense_layer方法对第一个函数进行测试,定义test_expand_reshape_tensor方法对第二个函数进行测试。
import tensorflow as tf
import utils
class UtilsTests(tf.test.TestCase):
def test_dense_layer(self