【BUG记录】layers.BatchNormalization()的使用

【BUG记录】 layers.BatchNormalization()的使用

我的目标是同tensorflow改写以下的pytorch搭建的CNN模型

#PyTorch搭建的模型
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()

        self.conv=nn.Sequential(
            # first layer
            nn.Conv2d(1,32,kernel_size=(2,5)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1,2)),
            # second layer
            nn.Conv2d(32, 32, kernel_size=(2, 3)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            # second layer
            nn.Conv2d(32, 32, kernel_size=(2, 2)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
        )

        self.dense_layer=nn.Sequential(
            nn.Flatten(),
            nn.Linear(1120,1156),
            nn.BatchNorm1d(1156),
            nn.ReLU(),
            nn.Dropout(0.5),

            nn.Linear(1156, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),

            nn.Linear(256,5),
        )

    def forward(self,input):
        feature_cnn=self.conv(input)
        # feature_cnn=feature_cnn.view(-1,32*5*7)
        output=self.dense_layer(feature_cnn)

        return output

以下是实际搭建的模型

#tensorflow搭建模型
class ConvNet(Model):

    def __init__(self):
        self.filter=filter
        super(ConvNet,self).__init__()

        self.conv=keras.Sequential([
            # first layer
            layers.Conv2D(32,kernel_size=(2,5)),
            layers.BatchNormalization(32),
            layers.ReLU(),
            layers.MaxPool2D((1, 2)),
            # second layer
            layers.Conv2D(32, kernel_size=(2,3)),
            layers.BatchNormalization(32),
            layers.ReLU(),
            layers.MaxPool2D((1, 2)),
            # third layer
            layers.Conv2D(32, kernel_size=(2, 2)),
            layers.BatchNormalization(32),
            layers.ReLU(),
        ])

        self.dense_layer=keras.Sequential([
            layers.Flatten(),

            layers.Dense(1056),
            layers.BatchNormalization(1056),
            layers.ReLU(),
            layers.Dropout(rate=0.5),

            layers.Dense(512),
            layers.BatchNormalization(512),
            layers.ReLU(),

            layers.Dense(256),
            layers.BatchNormalization(256),
            layers.ReLU(),

            layers.Dense(5)
        ])

    def call(self,x,is_training=False):
        x=tf.reshape(x, [-1, 8, 40, 1])
        x=self.conv(x)
        # x=tf.reshape(x,[-1,32*5*7])
        x=self.dense_layer(x)

        if not is_training:
            x=tf.nn.softmax(x)
        return x

在实际运行之后出现以下问题

Traceback (most recent call last):
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\IPython\core\interactiveshell.py", line 3361, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-31e68b6411fb>", line 1, in <cell line: 1>
    runfile('D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py', wdir='D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code')
  File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "D:\Software\Professional\Pycharm\Pycharm\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "D:/PostGraduate/Shanghai University/Research_Group/Task/EMG/self/code/EMG-left_and_right_arms - vote/main_code/cnn_model.py", line 40, in <module>
    pred=conv_net(batch_x)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1030, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\PostGraduate\Shanghai University\Research_Group\Task\EMG\self\code\EMG-left_and_right_arms - vote\main_code\cnn_function.py", line 83, in call
    x=self.conv(x)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\sequential.py", line 389, in call
    outputs = layer(inputs, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\base_layer.py", line 1006, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\keras\engine\functional.py", line 1442, in call
    return getattr(self._module, self._method_name)(*args, **kwargs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 1023, in __call__
    self._maybe_build(inputs)
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\engine\base_layer.py", line 2625, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "D:\Software\Professional\Anaconda\envs\tensorflow2-gpu\lib\site-packages\tensorflow\python\keras\layers\normalization.py", line 315, in build
    raise ValueError('Invalid axis: %s' % (self.axis,))
ValueError: Invalid axis: ListWrapper([32])

最后发现错误原因出现在layers.BatchNormalization(32)(tensorflow)和nn.BatchNorm2d(32)(pytorch)上,
我们来看tf中的定义layers.BatchNormalization()
可参考链接: link

tf.keras.layers.BatchNormalization(
    axis=-1, momentum=0.99, epsilon=0.001, center=True, scale=True,
    beta_initializer='zeros', gamma_initializer='ones',
    moving_mean_initializer='zeros',
    moving_variance_initializer='ones', beta_regularizer=None,
    gamma_regularizer=None, beta_constraint=None, gamma_constraint=None, **kwargs
)

在这里插入图片描述
而在pytorch中,链接: link.

因此正确的方式是将layers.BatchNormalization(32)改成layers.BatchNormalization(axis=3)即可。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

木星健谈能手

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值