keras 自定义层input_Keras编写自定义层--以GroupNormalization为例

a7a98135bd34e2d6b67b6e5cb486f15e.png

1. Group Normalization 介绍

Batch Normalization(BN)称为批量归一化,可加速网络收敛利于网络训练。但BN的误差会随着批量batch的减小而迅速增大。FAIR 研究工程师吴育昕和研究科学家何恺明合作的一篇论文 提出了一种新的与批量无关的Normalization 方法-[[1803.08494] Group Normalization]。GN 的主要工作是将通道分成组,并在每组内计算归一化的均值和方差。GN 的计算与批量大小无关,并且其准确度在各种批量大小下都很稳定。具体如下图(摘自论文):

9b93c090c942cee129507a65cdf40db1.png
BN 时的小批量会导致批量数据的统计两估算不准确,会显著增加模型误差。而无批量无关的GN方法得到的误差则相对稳定。

2. Keras自定义层方法

关于Keras中如何自定义层,可参考官方中文文档[编写你自己的层 - Keras 中文文档 ],[Keras简单自定义层例子]。自定义层中主要包括4种方法:

  • __init__(**kwargs):初始化方法,关键字参数保留,否则自定义层加载会报错。
  • build(input_shape):用于定义权重的方法
  • call(x): 自定义层具体功能的实现方法
  • get_config : 返回一个字典,获取当前层的参数信息。自定义层保存和加载时需要定义
  • compute_output_shape(input_shape):用于Keras可以自动推断shape

自定义层的保存和加载需要注意以下3点:

  • __init__(self, arg, **kwargs)初始化方法中关键字参数保留,否则自定义层加载会报错。
缺少**kwargs,TypeError: __init__() got an unexpected keyword argument 'name'
  • get_config(self)方法需要重写,否则网络结构无法保存。父类的config也需一并保存,将父类及继承类的config组装为字典形式,继承类config依据__init__方法传入的参数而定,具体如下:
缺少get_config方法,NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
def 
  • load_model()需为custom_objects参数赋值
缺少custom_objects,ValueError: Unknown layer: LayerName
_custom_objects 

Keras官网提供了两种Normalization的源码,分别是:

  • 批量归一化 Keras-BatchNormalization
  • 实例归一化 Keras-InstanceNormalization

两者的不同在于IN的统计量估算是批量无关的基于单张图片单个通道,不需要用滑动平均项来记录全局的统计量,体现在源码的差异为:

# BN code
# IN code

解释:

  • 所有自定义层都需要继承基础层Layer,并添加super().__init__(**kwargs)
  • **kwargs代表以字典方式继承父类
  • self.add_weight()是继承层Layer的方法,用于为变量添加权重,其中有参数trainable代表该参数的权重是否为可训练权重; 若trainable==True时,会执行self._trainable_weights.append(weight).
  • BN中需要添加moving_mean/variance滑动平均项的权重,且需要设置trainable==False,即为非训练参数。
  • self.add_update()用于更新滑动平均项
  • K.in_train_phase()针对训练状态选择不同的mean/variance计算BN

3. 定义Group Normalization层

源代码位置Bingohong/GroupNormalization-tensorflow-keras,里面包含了2个GN文件,分别是tensorflow和keras的实现版本,其中都包含了moving_average操作。

其实关于GN操作,是否需要apply moving_average是值得商榷的,论文中貌似没有明确提及,其他实现版本中的实现都是无moving_average操作。但通过对比IN、BN和GN特点及后期的实验对比,觉得GN应该是不需要moving_average操作的。因此这部分内容包括:

  • 主要介绍有moving_average操作的GN层的定义过程,而无moving_average操作时,只需要将对应的代码去掉。
  • 使用BN/GN_with_moving_average/GN_without_moving_average3种Normalization方法,对比U-net的实验结果。

keras GN层

完整代码在这里,以下仅解释部分关键代码。

# GN_with_moving_average code

实验对比结果

实验日志位于compare_log,包含3个文件:

  • train_bn.log -> unet+bn日志
  • train_gn_ema.log -> unet+gn(有moving_average操作)
  • train_gn_noema.log -> unet+gn(无moving_average操作)

结果说明:

  • gn without moving average 得到的val_loss会更低,可达到 0.2左右
  • gn with moving average 有时会一直存在很高的val_loss, 所以我觉得可能GN并不需要 apply moving average
  • bn 得到的val_ loss约为 0.26, 高于gn without moving average.

欢迎大家批评指正~ 谢谢谢谢~~~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值