- 检查输入数据(train与test)是否经过了归一化
- 设置model.eval()时,网络中所有bn层里超参 track_running_stats = False
但是,利用这种方式,模型在进行infer的时候必须要保持和训练时大致相当的batchsize,否则还是无法保持训练时的精度。 - 直接将BN层替换掉,尝试替换为layernorm或groupnorm。在我当时的任务中,将BN全部替换为layernorm会导致模型无法收敛;将BN层全部替换为groupnorm可以保证infer时精度与train时相当,且infer时batchsize可以为1,只是精度上相较于BN会降一些。
另:替换为groupnorm时,参数设置为32
def set_bn_fix_1(m):
classname = m.__class__.__name__
if classname.find('BatchNorm') != -1:
m.track_running_stats = False
def test(data, model, criterion):
print("******** Testing ********")
with torch.no_grad():
model.eval()
model.apply(set_bn_fix_1)