if __name__ == "__main__":
    # Script entry point: initialise a YoloBody network with every weight from
    # a checkpoint file whose parameter name and shape match the network's own.

    # Model to be initialised and the checkpoint to read from.
    model = YoloBody()
    model_path = 'model_data/yolo_swin_back.pth'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_dict = model.state_dict()
    pretrained_dict = torch.load(model_path, map_location=device)
    print(pretrained_dict.__sizeof__())

    # A checkpoint is either a flat {parameter_name: tensor} mapping or a
    # wrapper that nests that mapping under a top-level 'model' key. Unwrap
    # the nested form so the filter below compares parameter names rather
    # than wrapper keys (which would match nothing and load no weights).
    if isinstance(pretrained_dict.get('model'), dict):
        pretrained_dict = pretrained_dict['model']

    # Keep only the entries that exist in the model with an identical shape;
    # everything else is skipped so load_state_dict() cannot fail on a
    # missing key or a size mismatch.
    matched_weights = {
        name: weight
        for name, weight in pretrained_dict.items()
        if name in model_dict and np.shape(model_dict[name]) == np.shape(weight)
    }
    loaded_count = len(matched_weights)
    skipped_count = len(pretrained_dict) - loaded_count

    # Overlay the matched weights on the model's current state and load it.
    model_dict.update(matched_weights)
    model.load_state_dict(model_dict)

    # Number of checkpoint entries loaded, then number skipped.
    print(loaded_count, skipped_count)
# Note: the same approach can be used to merge parameters from several models.