如何冻结
一般要冻结特征提取层(pretrain layer)的bn 还有一些自己定义的bn不应该冻结 因此在自己的model里重写train
#示例程序 在自己写的model里添加
def train(self, mode=True):
"""
Override the default train() to freeze the BN parameters
"""
super(fintuneNet, self).train(mode)
if self.args.freeze_bn and mode==True:
print('modify train in ')
self.branch_cnn.apply(self.fix_bn)
#branch_cnn特征提取层 一般是一个 nn.Sequential()
return self
def fix_bn(self,m):
classname = m.__class__.__name__
# print('classname',classname)
if classname.find('BatchNorm') != -1:
m.eval()
m.track_running_stats=False
for i, (name1, p) in enumerate(m.named_parameters()):
p.requires_grad = False
train val 区别
最主要区别有两个
1、train状态下参数更新、使用是当前batch的统计量。runing means 不使用 但是forworad时候就更新了
2、val状态下参数不更新 、使用的是train的时候记录的runing means
3、track_running_stats 标识是train状态下是否更新 running means
推荐好文:Pytorch BN(BatchNormal)计算过程与源码分析和train与eval的区别_一只皮皮虾x的博客-CSDN博客