1. Group Normalization 介绍
Batch Normalization(BN)称为批量归一化,可加速网络收敛利于网络训练。但BN的误差会随着批量batch的减小而迅速增大。FAIR 研究工程师吴育昕和研究科学家何恺明合作的一篇论文 提出了一种新的与批量无关的Normalization 方法-[[1803.08494] Group Normalization]。GN 的主要工作是将通道分成组,并在每组内计算归一化的均值和方差。GN 的计算与批量大小无关,并且其准确度在各种批量大小下都很稳定。具体如下图(摘自论文):
2. Keras自定义层方法
关于Keras中如何自定义层,可参考官方中文文档[编写你自己的层 - Keras 中文文档 ],[Keras简单自定义层例子]。自定义层中主要包括4种方法:
- __init__(**kwargs):初始化方法,关键字参数保留,否则自定义层加载会报错。
- build(input_shape):用于定义权重的方法
- call(x): 自定义层具体功能的实现方法
- get_config : 返回一个字典,获取当前层的参数信息。自定义层保存和加载时需要定义
- compute_output_shape(input_shape):用于Keras可以自动推断shape
自定义层的保存和加载需要注意以下3点:
- __init__(self, arg, **kwargs)初始化方法中关键字参数保留,否则自定义层加载会报错。
缺少**kwargs,TypeError: __init__() got an unexpected keyword argument 'name'
- get_config(self)方法需要重写,否则网络结构无法保存。父类的config也需一并保存,将父类及继承类的config组装为字典形式,继承类config依据__init__方法传入的参数而定,具体如下:
缺少get_config方法,NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
def
- load_model()需为custom_objects参数赋值
缺少custom_objects,ValueError: Unknown layer: LayerName
_custom_objects
Keras官网提供了两种Normalization的源码,分别是:
- 批量归一化 Keras-BatchNormalization
- 实例归一化 Keras-InstanceNormalization
两者的不同在于IN的统计量估算是批量无关的基于单张图片单个通道,不需要用滑动平均项来记录全局的统计量,体现在源码的差异为:
# BN code
# IN code
解释:
- 所有自定义层都需要继承基础层Layer,并添加super().__init__(**kwargs)
- **kwargs代表以字典方式继承父类
- self.add_weight()是继承层Layer的方法,用于为变量添加权重,其中有参数trainable代表该参数的权重是否为可训练权重; 若trainable==True时,会执行self._trainable_weights.append(weight).
- BN中需要添加moving_mean/variance滑动平均项的权重,且需要设置trainable==False,即为非训练参数。
- self.add_update()用于更新滑动平均项
- K.in_train_phase()针对训练状态选择不同的mean/variance计算BN
3. 定义Group Normalization层
源代码位置Bingohong/GroupNormalization-tensorflow-keras,里面包含了2个GN文件,分别是tensorflow和keras的实现版本,其中都包含了moving_average操作。
其实关于GN操作,是否需要apply moving_average是值得商榷的,论文中貌似没有明确提及,其他实现版本中的实现都是无moving_average操作。但通过对比IN、BN和GN特点及后期的实验对比,觉得GN应该是不需要moving_average操作的。因此这部分内容包括:
- 主要介绍有moving_average操作的GN层的定义过程,而无moving_average操作时,只需要将对应的代码去掉。
- 使用BN/GN_with_moving_average/GN_without_moving_average3种Normalization方法,对比U-net的实验结果。
keras GN层
完整代码在这里,以下仅解释部分关键代码。
# GN_with_moving_average code
实验对比结果
实验日志位于compare_log,包含3个文件:
- train_bn.log -> unet+bn日志
- train_gn_ema.log -> unet+gn(有moving_average操作)
- train_gn_noema.log -> unet+gn(无moving_average操作)
结果说明:
- gn without moving average 得到的val_loss会更低,可达到 0.2左右
- gn with moving average 有时会一直存在很高的val_loss, 所以我觉得可能GN并不需要 apply moving average
- bn 得到的val_ loss约为 0.26, 高于gn without moving average.
欢迎大家批评指正~ 谢谢谢谢~~~