Tensorflow2 Warning incompatible shape()

最近在做实验,设计了一个中间层,有两个输入,一个是前面编码器的输出encoder_output,形状为(None,z_dim)的tensor,另外一个是gumbel_softmax的温度temp,形状为(None,1)的tensor。

def make_gumbel_layer(n_class):
    encoder_output = tf.keras.Input(shape=(z_dim,))
    temp = tf.keras.Input(shape=(1,))
    x = tf.keras.layers.Dense(n_class)(encoder_output)
    gumbel_label = gumbel_softmax(x, temp)
    model = tf.keras.Model(inputs=[encoder_output, temp], outputs=gumbel_label)
    return model

 调用的代码为

    with tf.GradientTape() as ae_tape:
        encoder_output = encoder(batch_x, training=True)
        gumbel_label = gumbel_layer([encoder_output, temp])
        decoder_output = decoder([encoder_output, gumbel_label], training=True)

结果报warning。

尝试在调用的代码中输出了temp的数据类型

    with tf.GradientTape() as ae_tape:
        encoder_output = encoder(batch_x, training=True)
        print(type(temp))
        gumbel_label = gumbel_layer([encoder_output, temp])
        decoder_output = decoder([encoder_output, gumbel_label], training=True)

输出结果表明temp是一个float object,不是tensor。可以想到,temp被自动转换为形如()的tensor,即tensor的标量。而模型定义输入是(None, 1)。所以报warning。

所以只需要把temp 转换为 形如(1,1)的tensor即可。将调用的代码改为

    with tf.GradientTape() as ae_tape:
        encoder_output = encoder(batch_x, training=True)
        gumbel_label = gumbel_layer([encoder_output, tf.reshape(temp,[1,1])])
        decoder_output = decoder([encoder_output, gumbel_label], training=True)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值