训练的时候,发现epochs设置多了,训练中途发现收敛效果还可以,不用继续跑太多轮,于是想缩减epochs。但没找到解决的帖子…
修改步骤:
1.首先train文件的参数改一下:
- opochs改成减少后的轮数,比如先前是200,这次改成100
- model路径改成上次训练的last.pt
- 为了防止路径错误,建议全部使用绝对路径,然后把"\“都换成”\"
from ultralytics import YOLO
if __name__ == '__main__':
model = YOLO("runs\\detect\\train\\weights\\last.pt")
results = model.train(data="C:\\Users\\Administrator\\Desktop\\ultralytics-main\\ultralytics-main\\ultralytics\\datasets\\mask\\data.yaml", epochs=100, batch=4, workers=2, resume=True, device=0)
2. 在trainer.py的__init__() 构造函数里,先用一个变量接收config的epochs
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the BaseTrainer class.
Args:
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
overrides (dict, optional): Configuration overrides. Defaults to None.
"""
self.args = get_cfg(cfg, overrides)
################修改处################
self.resume_epochs = self.args.epochs# 添加变量接收config,即你规定的epochs
######################################
self.check_resume(overrides)
...
...
...
3. 在trainer.py的check_resume函数中恢复self.epochs
def check_resume(self, overrides):
...
...
if resume:
try:
...
...
resume = True
self.args = get_cfg(ckpt_args)
############修改处#####################
self.args.epochs = self.resume_epochs #重新覆盖self.args.epochs数值
#######################################
self.args.model = str(last) # reinstate model
...
...
...
4.结束
原理:
其实是在resume里,执行完self.args后,args参数全部都从之前训练的断点中恢复了,被覆盖掉了
先看一下原生代码执行顺序:
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
...
self.args = get_cfg(cfg, overrides)
self.resume_epochs = self.args.epochs
self.check_resume(overrides)
...
self.epochs = self.args.epochs
self.args = get_cfg(cfg, overrides)
先从config接收参数,也就是我们自己规定的参数self.check_resume(overrides)
检查并恢复断点,也就是resume为True的话,参数从上次训练的pt文件种读取,问题就出现在这里了。可以看到这里有一句self.args = get_cfg(ckpt_args)
,执行后,刚才的self.args就被覆盖了,这也是为什么我们自己设置的参数没用,因为被覆盖了。可以看下这段代码:
def check_resume(self, overrides):
"""Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume
if resume:
try:
...
...
resume = True
#######从断点读取参数#########
self.args = get_cfg(ckpt_args)
...
...
- 接着,执行
self.epochs = self.args.epochs
,epochs上限被最终赋值。
所以,做法就是在一开始用一个变量(我的是resume_epochs)备份我们设置的epochs值。然后再在恢复断点的时候,也就是epochs被覆盖了的时候,再用这个变量还回去。