之前我发的文章里写到过AdaBN在pytorch中的实现方法,因为自己当时并不熟悉pytorch,因为也没有去验证。现在确实用pytorch的人不少,有不少朋友也在问代码的事,正好我也准备学习pytorch,按照之前的方法试了试,效果还是不错的,准确率从92.27%提升到了96.80%。
其实从原理上讲确实很容易理解,在常规的训练中,BN的参数包含了训练集的均值和方差信息,对于同分布的数据来说,训练之后的BN参数无需调整。但是在迁移学习任务中,目标域的分布不同是一个很明显的特征,而目标域的均值和方差是容易获得的。在测试的时候通过设置BN中的track_running_stats=False可以直接将BN里源域的均值方差替换为目标域的均值方差,实际效果其实是很明显的,提升了4.53%,但是这个是6次迁移学习准确率的均值,因此实际的准确率提升要更明显。
另外,在我之前的测试中发现,只要模型中含有BN,使用AdaBN的方法均有较好的表现,其中一个shufflenet模型的准确率从96%提升到了98%,如果对性能更强的模型使用的话,99%也可能不是问题。
下面是模型代码及改动的地方
class WDCNN(nn.Module):
def __init__(self, in_channel=1, out_channel=10,AdaBN=True):
super(WDCNN, self).__init__()