加载部分预训练模型
#加载model,model是自己定义好的模型
resnet50 = models.resnet50(pretrained=True)
model =Net(...)
#读取参数
pretrained_dict =resnet50.state_dict()
model_dict = model.state_dict()
#将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {
k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
在Mixformer中的具体实现:
msvit_spec = config.MODEL.BACKBONE#超参列表
msvit = ConvolutionalVisionTransformer(in_chans=3,act_layer=QuickGELU,norm_layer=partial(LayerNorm, eps=1e-5),init=getattr(msvit_spec, 'INIT', 'trunc_norm'),spec=msvit_spec)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++&#