目录
yolov5加载预训练,把尺度相同的层拷贝参数,否则就不拷贝,重新训练
pytorch 自己完善的 加载通道不对齐预训练,会把参数切分或者补齐进行拷贝
这个方式加载预训练模型后,新网络可以沿用旧网络的权重,可以正常训练:
加载预训练,自定义分类数,删掉分类层参数
from easydict import EasyDict as edict
opts = edict({'width_multiplier':1,
"model.classification.n_classes":2,
'attn_norm_layer': "layer_norm_2d", "model.normalization.name": "batch_norm", "model.activation.name": "swin"})
model = MobileViTv3(opts)
pth_path = r'.\mobilevitv3_1_0_0\checkpoint_ema_best.pt'
state_dict = torch.load(pth_path, map_location='cpu')
compatible_state_dict = {}
for k, v in state_dict.items():
if 'module.' in k:
compatible_state_dict[k[7:]] = v
else:
if "classifier" in k:
continue
compatible_state_dict[k] = v
model.load_state_dict(compatible_state_dict, strict=False)
yolov5加载预训练,把尺度相同的层拷贝参数,否则就不拷贝,重新训练
def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint s