train.py代码学习
train.py是YOLOv7的训练脚本。总体代码流程如下:准备工作:数据 + 模型 + 学习率 + 优化器;训练过程:一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:开始训练、训练一个epoch前、训练一个batch前、训练一个batch后、训练一个epoch后、评估验证集、结束训练。
提示:以下是本篇文章正文内容,下面案例可供参考
一、执行train()函数
1.引入参数
代码如下(示例):
hyp: 超参数,不使用超参数进化的前提下也可以从opt中获取
opt: 全部的命令行参数
device: 指的是装载程序的设备
tb_writer=None是指将tensorboard的SummaryWriter类的实例化对象赋值为None。如果不需要使用tensorboard,可以将tb_writer设置为None。
def train(hyp, opt, device, tb_writer=None):
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
save_dir, epochs, batch_size, total_batch_size, weights, rank, freeze = \
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, opt.freeze
2.创建训练权重目录和保存路径
代码如下(示例):
wdir.mkdir(parents=True, exist_ok=True)是用于创建文件夹的语句。其中,wdir是文件夹的路径,parents=True表示如果父文件夹不存在则创建父文件夹,exist_ok=True表示如果文件夹已经存在则不会抛出异常
wdir = save_dir / 'weights' #####记录训练日志保存路径
wdir.mkdir(parents=True, exist_ok=True) # make dir
last = wdir / 'last.pt'
best = wdir / 'best.pt'
results_file = save_dir / 'results.txt'
系统会产生两个模型,一个是last.pt,一个是best.pt。顾名思义,last.pt即为训练最后一轮产生的模型,而best.pt是训练过程中,效果最好的模型。
3.设置参数的保存路径
yaml.safe_load(f)是加载yaml的标准函数接口,保存超参数为yaml配置文件。 yaml.safe_dump()是将yaml文件序列化,保存命令行参数为yaml配置文件。
vars(opt) 的作用是把数据类型是Namespace的数据转换为字典的形式。
with open(save_dir / 'hyp.yaml', 'w') as f:
yaml.dump(hyp, f, sort_keys=False)
with open(save_dir / 'opt.yaml', 'w') as f:
yaml.dump(vars(opt), f, sort_keys=False)
将超参数和程序中使用的选项保存到两个单独的YAML文件中。第一个文件名为hyp.yaml,包含超参数。第二个文件名为opt.yaml,包含程序中使用的选项。
4.创建训练权重目录和保存路径
代码如下(示例):
wdir.mkdir(parents=True, exist_ok=True)是用于创建文件夹的语句。其中,wdir是文件夹的路径,parents=True表示如果父文件夹不存在则创建父文件夹,exist_ok=True表示如果文件夹已经存在则不会抛出异常
wdir = save_dir / 'weights' #####记录训练日志保存路径
wdir.mkdir(parents=True, exist_ok=True) # make dir
last = wdir / 'last.pt'
best = wdir / 'best.pt'
results_file = save_dir / 'results.txt'
系统会产生两个模型,一个是last.pt,一个是best.pt。顾名思义,last.pt即为训练最后一轮产生的模型,而best.pt是训练过程中,效果最好的模型。
5.加载日志信息
loggers = {'wandb': None} # loggers dict
if rank in [-1, 0]:
opt.hyp = hyp # add hyperparameters
run_id = torch.load(weights, map_location=device).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
loggers['wandb'] = wandb_logger.wandb
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
WandbLogger是一个Python类,用于将模型训练过程中的超参数和输出指标记录到Weights & Biases(wandb)中。它能够自动记录模型训练过程中的超参数和输出指标,然后可视化和比较结果,并快速与同事共享结果。通过wandb,能够给你的机器学习项目带来强大的交互式可视化调试体验,能够更好地理解模型的性能和行为。
总结
学习过程的记录,如有错误敬请指正,谢谢。