使用场景:需要完全冻结某部分的 weight 与 BN 层
加载预训练模型时,如果只将 para.requires_grad = False ,并不能完全冻结模型的参数,因为模型中的 BN 层并不随 loss.backward() 与 optimizer.step() 来更新,而是在模型 forward 的过程中基于动量来更新,因此需要每个 forward 之前冻结 BN 层:
完整的冻结方式如下:
'''
一堆代码
'''
# 冻结BN
def freeze_bn(m):
classname = m.__class__.__name__
if classname.find('BatchNorm') != -1:
m.eval()
'''
一堆代码
'''
freeze_state_dict = torch.load(opt.loadckpt_freeze)
frozen_list = [k for k, _ in freeze_state_dict['state_dict'].items() if k in model_dict]
# 先冻结除了 BN 以外的参数
for param in model.named_parameters():
if param[0] in frozen_list: # 需要冻结的参数列表
param[1].requires_grad = False
# 优化器优化的参数只包含需要梯度更新的参数
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, betas=(0.9,0.999))
'''
一堆代码
'''
for epoch in range(opt.epoch):
model.train()
optimizer.zero_grad()
# 冻结BN
model.apply(freeze_bn)
# 前向传播
output = model(input)
loss = loss_F(gt, output)
loss.backward()
optimizer.step()