def func_converse(src_path=r'./resbinet/ressgd_child.pth', dst_path='sgd.pth',
                  heads=None, n_classes=1):
    """Copy compatible weights from a saved state dict into a fresh ResBiNet and save it.

    Builds ``deletedres.ResBiNet(n_classes=n_classes, heads=heads)`` and loads
    the file at *src_path* with ``torch.load``. It copies every entry whose key
    exists in the new model's ``state_dict`` with the same shape, then writes
    the model's resulting ``state_dict`` to *dst_path*. Parameters missing from
    the source, or whose shape differs, keep the new model's initial values.

    Args:
        src_path: Path of the source file. Assumed to hold a plain state dict
            (name -> tensor); TODO confirm it is not a wrapped checkpoint.
        dst_path: Path the converted state dict is written to.
        heads: Passed to ``ResBiNet`` as ``heads``. ``None`` means
            ``{'hm': 1, 'vaf': 2, 'haf': 1}``, the previously hard-coded value.
        n_classes: Passed to ``ResBiNet`` as ``n_classes``.
    """
    if heads is None:
        # Build a fresh dict per call instead of using a shared mutable default.
        heads = {'hm': 1, 'vaf': 2, 'haf': 1}
    model = deletedres.ResBiNet(n_classes=n_classes, heads=heads)
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only machine. load_state_dict copies values into the model's own
    # tensors, so the loaded tensors' device does not leak into the result.
    save_model = torch.load(src_path, map_location='cpu')
    model_dict = model.state_dict()
    # Keep only entries the new model can accept: same key AND same shape.
    # A key match alone is not enough, because load_state_dict raises on a
    # size mismatch.
    state_dict = {
        k: v for k, v in save_model.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    print(state_dict.keys())
    skipped = [k for k in save_model if k not in state_dict]
    if skipped:
        print('skipped (missing or shape mismatch):', skipped)
    # Merge the transferred tensors over the fresh model's own state so every
    # key the model expects is present for load_state_dict.
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    torch.save(model.state_dict(), dst_path)
04-13
1056
08-17
6761
05-25
3011