- 既然说 paritially known,那就想办法直接告诉后台shape到底什么样
- tf.shape 就可以获取当前张量的shape
- 样例代码
class Sampling(layers.Layer):
def call(self, inputs):
z_mean, z_log_var = inputs
batch = tf.shape(z_mean)[0] # 获取第一维
dim = tf.shape(z_mean)[1] # 获取第二维
epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
return z_mean + tf.exp(0.5 * z_log_var) * epsilon