方法1:直接加载网络
# Method 1: load the entire serialized network object.
# NOTE(review): torch.load unpickles the full model object, so the model's
# class definition must be importable at load time — confirm it is on this path.
import torch

pthfile = r'E:\models\squeezenet1_1.pth'
# map_location='cpu' makes the load work even on machines without a GPU
# when the checkpoint was saved from CUDA tensors (original would crash there).
net = torch.load(pthfile, map_location='cpu')
print(net)
方法2:
import torch
def remove_prefix(state_dict, prefix):
    """Strip *prefix* from every key of *state_dict* that starts with it.

    Old-style checkpoints saved from DataParallel models store all parameter
    names with a shared 'module.' prefix; removing it makes the keys match a
    plain (non-wrapped) model.

    Args:
        state_dict: mapping of parameter name -> value (e.g. tensors).
        prefix: leading string to drop, e.g. 'module.'.

    Returns:
        A new dict; keys starting with *prefix* have it removed, other keys
        are kept unchanged. Values are passed through untouched.
    """
    # split(prefix, 1)[-1] drops exactly one leading occurrence of the prefix.
    strip = lambda key: key.split(prefix, 1)[-1] if key.startswith(prefix) else key
    return {strip(key): value for key, value in state_dict.items()}
def load_model(model, pretrained_path, load_to_cpu):
    """Load pretrained weights from a checkpoint file into *model*.

    Handles checkpoints that wrap the weights in a 'state_dict' entry and
    strips the 'module.' prefix left behind by DataParallel training.

    Args:
        model: the network to receive the weights.
        pretrained_path: path to the .pth checkpoint file.
        load_to_cpu: if True, remap all tensors to CPU storage; otherwise
            move them to the current CUDA device.

    Returns:
        The same *model* instance with the weights loaded.

    Raises:
        RuntimeError: from load_state_dict if keys/shapes mismatch
            (strict=True is the default).
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        # Keep every storage where it is deserialized, i.e. on CPU.
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    # Some checkpoints nest the weights under a 'state_dict' key.
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict['state_dict'], 'module.')
    else:
        pretrained_dict = remove_prefix(pretrained_dict, 'module.')
    # NOTE(review): check_keys is not defined in this file — presumably a
    # project helper that compares checkpoint keys against model.state_dict().
    check_keys(model, pretrained_dict)
    # strict=True (default) raises on any missing/unexpected key; pass
    # strict=False to tolerate partial checkpoints.
    model.load_state_dict(pretrained_dict)
    return model
保存网络模型参数方法1:
# Unwrap DataParallel before saving: net.module.state_dict() keeps parameter
# names free of the 'module.' prefix, so the file loads into a plain model.
save_net = net
if hasattr(net, 'module'):
    save_net = net.module
# BUG FIX: the original f-string ended with '{epoch)' — a mismatched brace
# that is a SyntaxError; it must be '{epoch}'.
torch.save(save_net.state_dict(),
           f'{save_folder}/{map_score:.4f}_{chazhun_Rate:.4f}_{last_loss:.4f}_{lr:.1e}_{epoch}.pth')
保存网络模型参数方法2:
import os  # required for os.path.join below — never imported in this snippet

# Build the checkpoint filename from accuracy and training step, then save
# only the underlying module's weights when the model is wrapped in DDP
# (so the saved keys carry no 'module.' prefix).
# NOTE(review): `self`, `acc2`, `global_step` and `backbone` are free names
# here — this fragment was lifted from inside a method; confirm its context.
path_module = os.path.join(self.output, f"{acc2:.4f}_{global_step}_net.pth")
if isinstance(backbone, torch.nn.parallel.DistributedDataParallel):
    print("save backbone.module")
    torch.save(backbone.module.state_dict(), path_module)
else:
    torch.save(backbone.state_dict(), path_module)