# --------------- restore model weights from a checkpoint ---------------#
# Expects the following names to be defined earlier in the script:
#   is_restore (bool), restore_ckpt (path str), args, model, is_face (bool).
print("# --------------- restore ---------------#")
if is_restore:
    print("restore from ckpt=", restore_ckpt)
    # NOTE(review): if the checkpoint was saved on GPU and is resumed on a
    # different device, torch.load may need map_location= — confirm with caller.
    checkpoint = torch.load(restore_ckpt)
    print("=> resuming from checkpoint ")
    print(restore_ckpt)
    print("==> checkpoint [epoch={}, iteration={}]".format(checkpoint["epoch"], checkpoint["iteration"]))
    args.restore_epoch_iter = "Epoch=[%d]_Iter=[%d]" % (checkpoint["epoch"], checkpoint["iteration"])
    restore_dir = os.path.dirname(restore_ckpt)
    restore_main_dir = os.path.dirname(restore_dir)

    if is_face:
        # Restore tensor-by-tensor through state_dict so that BatchNorm
        # running_mean / running_var buffers are copied too — iterating
        # named_parameters() would miss them, since running stats are
        # buffers, not parameters. Entries whose name contains "_isface"
        # are skipped (they belong to the new head trained from scratch).
        state_dict = checkpoint["state_dict"]
        own_state = model.state_dict()
        print("------------------------------------ restore list --------------------------------")
        for name, param in state_dict.items():
            if "_isface" in name:
                continue
            if name not in own_state:
                # fix: guard against KeyError for checkpoint keys that do
                # not exist in the current model architecture
                print("skip (not in model): ", name)
                continue
            if isinstance(param, torch.nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            own_state[name].copy_(param)
            print("restore : ", name)
    else:
        model.load_state_dict(checkpoint["state_dict"])

if is_face:
    print("#---------------------setting isface------------------------#")
    # named_modules() recurses into every submodule and its children, so all
    # nested BatchNorm layers are reached. Putting a restored BN layer in
    # eval() mode freezes its running statistics; BN layers belonging to the
    # "_isface" head keep training normally.
    for name, m in model.named_modules():
        print(name, m)
        if isinstance(m, nn.BatchNorm2d) and '_isface' not in name:
            m.eval()
    # Freeze gradients for every parameter outside the "_isface" head so
    # only the new head is optimized.
    for name, param in model.named_parameters():
        if "_isface" not in name:
            param.requires_grad = False

# List all parameters with their trainability so the freeze result can be
# verified by eye in the log.
print("------------------------------------ train list -----------------------------------")
for name, param in model.named_parameters():
    print(name.ljust(50), str(param.shape).ljust(30), param.requires_grad)
How to freeze part of a model's weights in PyTorch while training the remaining weights
最新推荐文章于 2024-10-03 09:02:12 发布