前情:在做实验的时候一开始用的是网上的代码,最后保存两个模型:model.pth,optimizer.pth,但是最终的实验结果和预期相差太大,准备在最好的模型基础上继续训练。
model.pth和optimizer.pth的区别
这里不讲这两个的具体作用,主要从用法开始讲:
1.如果你只是在已得到的最佳结果进行测试的话,只需要 model.pth 即可。
2.如果是想在已有的最佳结果上面继续训练,继续提升结果,那就需要到 model.pth 和optimizer.pth。
如何将已得到的模型参数(model.pth)和优化器(optimizer.pth)参数代入到工程代码中继续训练
需要修改以下几个地方:
1.在创建model之后加载model.pth (具体添加方法如以代码中的备注1)
2.在创建optimizer之后加载optimizer.pth(具体添加方法如代码中备注2)
3.给出model.pth和optimizer.pth的地址,也就是代码中的cfg.MODEL.CONTINUE.MODEL,cfg.MODEL.CONTINUE.OPTIMIZER,这个就是模型具体地址的位置。
4.非常重要,要将model从cpu转换到cuda,否则会报错
# encoding: utf-8
def train(cfg):
model = build_fcn_model(cfg)
#4.非常重要,要将model从cpu转换到cuda,否则会报错
model=model.cuda()
# 1.在model下面添加model.pth ,其中cfg.MODEL.CONTINUE.MODEL为模型保存的位置
model.load_state_dict(torch.load(cfg.MODEL.CONTINUE.MODEL))
optimizer = make_optimizer(cfg, model)
# 2.在optimizer 下面添加optimizer .pth ,其中cfg.MODEL.CONTINUE.optimizer 为模型保存的位置
optimizer.load_state_dict(torch.load(cfg.MODEL.CONTINUE.OPTIMIZER))
arguments = {}
data_loader = make_data_loader(cfg, is_train=True)
val_loader = make_data_loader(cfg, is_train=False)
do_train(
cfg,
model,
data_loader,
val_loader,
optimizer,
cross_entropy2d,
)
def main():
parser = argparse.ArgumentParser(description="PyTorch Training")
parser.add_argument(
"--config_file", default="", help="path to config file", type=str
)
parser.add_argument("opts", help="Modify config options using the command-line", default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
if args.config_file != "":
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
output_dir = cfg.OUTPUT_DIR
if output_dir and not os.path.exists(output_dir):
mkdir(output_dir)
logger = setup_logger("Model", output_dir, 0)
logger.info("Using {} GPUS".format(num_gpus))
logger.info(args)
if args.config_file != "":
logger.info("Loaded configuration file {}".format(args.config_file))
with open(args.config_file, 'r') as cf:
config_str = "\n" + cf.read()
logger.info(config_str)
logger.info("Running with config:\n{}".format(cfg))
train(cfg)
if __name__ == '__main__':
main()