自定义神经网络时的注意事项

文章讲述了在使用tf.keras.Model自定义神经网络时,遇到因输入形状不匹配导致的ValueError,特别是当处理共享预处理层和不同输入形状时。作者提供了通过定义独立的卷积层和flatten层来解决这个问题的方法。
摘要由CSDN通过智能技术生成

问题描述

`

通过继承tf.keras.Model自定义神经网络模型时遇到的一系列问题。

代码如下,

class STFT_ConV2D(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_layer = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(768, activation='relu')
        ])

        self.add = tf.keras.layers.Add()
        self.output_dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        x, y = inputs
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_x)(x)
        x = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_x)(x)
        x = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_x)(x)
        x = self.pre_layer(x)

        y = tf.keras.layers.Conv2D(filters=3, kernel_size=8, input_shape=Input_shape_y)(y)
        y = tf.keras.layers.Conv2D(filters=3, kernel_size=16, input_shape=Input_shape_y)(y)
        y = tf.keras.layers.Conv2D(filters=1, kernel_size=32, input_shape=Input_shape_y)(y)
        y = self.pre_layer(y)
        output = self.add([x, y])
        output = self.output_dense(output)
        return output

产生的bug为,

  ValueError: Exception encountered when calling layer 'sequential' (type Sequential).
        
  Input 0 of layer "dense" is incompatible with the layer: expected axis -1 of input shape to have value 11368, but received input with shape (None, 210680)

x输入和y输入都使用了成员变量pre_layer,共享了pre_layer层,也就共享了pre_layer层的参数矩阵和结构。
由于x先经过三层卷积层后shape由原来的shape=(360, 256, 109, 1)变成了shape=(360, 203, 56, 1)
再经过pre_layer层里的Flatten时,除“ batchsize ”轴(axis=0)外,其余轴被铺平,输出shape=(360,11368)。接着处理y输入,经过三层卷积层后,shape由原来的shape=(360, 511, 513, 1)变成了shape=(360,458, 460, 1),之后执行到y = self.pre_layer(y)时,如果执行成功,则输出shape=(360,21068),此时与x的shape=(360,11368)维度冲突,从而产生异常。

要点归纳:

  1. 通过继承tf.keras.Model写神经网络模型时,每一个神经网络层只能被同一个输入占有。
  2. 所有tf.keras.layers下的层对象不能直接出现在call()方法中,必须以成员变量的形式在构造器中定义,然后在call()方法中通过self.成员变量的方式调用
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值