工作了之后发现迁移学习/领域自适应的落地场景是超出想象的多。最近一直在搞domain adaption相关的东西,工(mo)作(yu)之余分享一个很容易被忽略的小trick:如果来自source 和 target domain的数据分布很不一致,试试在target domain 数据 forward的时候,关掉BN的自动更新。
# 关掉模型的BN(pytorch写法)
for m in model.modules():
if isinstance(m, torch.nn.BatchNorm2d):
m.eval()
这个trick不一定在所有的数据上都有效,但是解决了我自己数据集上的问题,以下详细讲一下为啥要这么搞。
我的baseline是使用deep CORAL的思想来做图片分类的领域自适应,但是训着训着,我发现虽然网络在source domain上单独训练时表现很好,但是一旦加入target domain的数据训练,哪怕只有forward操作,都会导致有监督分类输出的结果不收敛,甚至把预训练的权重都带跑偏。
一开始我以为是torch反向传播机制的问题,后来发现,即使在target domain data forward的时候关掉网络的grad都会导致结果不收敛,这简直太诡异了,这个奇怪的bug困扰了我两天,在我差点要去看torch的源码时,才想到是不是BN的问题。
关掉grad时,只要模型还处在`model.train()`状态,就会持续更新BN的参数,当source 和 domain的分布很不一样的时候,一次forward就会让BN的参数发生很大的变化,进而使得输入的数据不断波动,让后面的结果一直不收敛。
后来我也查阅了一些资料,有关于在做transfer和domain adaption的时候应不应该关掉BN的问题,其实真的不一定,俗话说实践出真知,还是应该根据不同的场景多试试,毕竟也就是几行代码的事。
(奇怪的调参经验又增加了!)
参考资料:
https://stackoverflow.com/questions/44609533/can-we-use-batch-normalization-with-transfer-learning-for-an-instance-with-diffe