detrex项目地址
是一个包含多种detection transformer模型的框架。
该框架支持断点续传,重新训练,但只能从最后一个iter的模型加载并训练。
例如我总共需要训练59999个iter,已经训练29999个iter,程序中断,最后一个模型的保存路径是
'../output/029999.pth'
但我想从24999这个模型,即
'../output/024999.pth'
继续训练。
需要修改三个地方。
1.训练config文件
将代码中的
train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl"
修改为
train.init_checkpoint = '../output/024999.pth'
train.init_iter=24999
这里是自己想要开始训练的模型路径及对应的iter。
2.tools/train_net.py
搜索模型加载部分代码
if args.resume and checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
start_iter = trainer.iter + 1
else:
start_iter = 0
将之改为
if args.resume and checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
start_iter = trainer.iter + 1
elif args.resume and cfg.train.init_iter:
start_iter = int(cfg.train.init_iter) + 1
else:
start_iter = 0
3.detrex/lib/python3.9/site-packages/fvcore/common/checkpoint.py
这一步最为关键,需要找到detrex安装环境中的python包,修改模型加载部分。可以从train_net.py的*checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)*这一函数直接跳转过去。
将checkpoint.py中的resume_or_load函数找到,并将
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])
修改为
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
elif resume and self.path_manager.exists(path):
return self.load(path)
else:
return self.load(path, checkpointables=[])
搞定。
剩下的按照说明照常运行即可。