YOLOv5 Code Walkthrough (train.py)

1. train.py

1.1 Mixed precision training with NVIDIA's apex

mixed_precision = True
try:  # Mixed precision training https://github.com/NVIDIA/apex
    from apex import amp
except ImportError:
    print('Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex')
    mixed_precision = False  # not installed

   
   
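The try/except above finds out whether apex is installed by attempting the import. If you only want to know whether the package is present without importing it, one alternative is importlib.util.find_spec from the standard library. The sketch below shows that approach. has_module is a hypothetical helper name and is not part of YOLOv5.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module can be located, without importing it.

    Hypothetical helper for illustration; train.py itself uses try/except.
    """
    return importlib.util.find_spec(name) is not None


mixed_precision = has_module("apex")
if not mixed_precision:
    print("Apex recommended for faster mixed precision training: https://github.com/NVIDIA/apex")
```

find_spec only checks that the package can be located. A broken apex build would therefore still report True here, whereas the try/except in train.py catches an import that actually fails.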

1.2 Set up file paths

wdir = 'weights' + os.sep  # weights dir
os.makedirs(wdir, exist_ok=True)
last = wdir + 'last.pt'
best = wdir + 'best.pt'
results_file = 'results.txt'

   
   

1.3 Read the dataset paths

# Configure
init_seeds(1)
with open(opt.data) as f:
    data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict
train_path = data_dict['train']
test_path = data_dict['val']
nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes

   
   
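The data YAML read here should contain at least the keys train, val and nc, plus names, which section 1.5 uses. The sketch below assumes yaml.load has already returned a plain dict and shows how nc is derived from it. The paths, the class names and the extra len(names) check are illustrative additions, not part of train.py.

```python
def read_class_info(data_dict: dict, single_cls: bool = False):
    """Derive (nc, names) the way section 1.3 derives nc.

    The len(names) consistency check is an extra illustration, not from train.py.
    """
    nc = 1 if single_cls else int(data_dict["nc"])  # number of classes
    names = data_dict["names"]
    if not single_cls and len(names) != nc:
        raise ValueError("nc=%d but %d names are listed" % (nc, len(names)))
    return nc, names


# Stand-in for what yaml.load would return for a hypothetical two-class dataset
data_dict = {"train": "../data/images/train/", "val": "../data/images/val/",
             "nc": 2, "names": ["cat", "dog"]}
print(read_class_info(data_dict))
```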

1.4 Remove previous results

# Remove previous results
for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
    os.remove(f)

   
   

1.5 Create the model

# Create model
model = Model(opt.cfg).to(device)
assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (opt.data, nc, opt.cfg, model.md['nc'])
model.names = data_dict['names']

   
   

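assert evaluates a condition and raises an AssertionError with the given message if the condition is false. At this point the model has already been built from opt.cfg. The assert checks that the number of classes in the model config (model.md['nc']) matches nc from the data file before training continues.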

1.6 Check the training and test image sizes

# Image sizes
gs = int(max(model.stride))  # grid size (max stride)
imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]  # verify imgsz are gs-multiples

   
   
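check_img_size is defined in YOLOv5's utils, not in train.py. As far as I can tell, it rounds the requested size up to the nearest multiple of gs. gs is the largest stride (32 for the standard models), and every detection layer must divide the input evenly. The sketch below works under that assumption. check_img_size_sketch and its warning text are my own, not the library's.

```python
import math


def make_divisible(x: int, divisor: int) -> int:
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def check_img_size_sketch(img_size: int, gs: int = 32) -> int:
    """Return img_size rounded up to a multiple of the max stride gs, warning if it changed."""
    new_size = make_divisible(img_size, gs)
    if new_size != img_size:
        print("WARNING: img_size %g is not a multiple of max stride %g, using %g" % (img_size, gs, new_size))
    return new_size


print(check_img_size_sketch(640), check_img_size_sketch(641))
```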

1.7 Set up the optimizer

# Optimizer
nbs = 64  # nominal batch size
accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
for k, v in model.named_parameters():
    if v.requires_grad:
        if '.bias' in k:
            pg2.append(v)  # biases
        elif '.weight' in k and '.bn' not in k:
            pg1.append(v)  # apply weight decay
        else:
            pg0.append(v)  # all else

optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
    optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
optimizer.add_param_group({'params': pg1, 'weight_decay': hyp['weight_decay']})  # add pg1 with weight_decay
optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
print('Optimizer groups: %g .bias, %g conv.weight, %g other' % (len(pg2), len(pg1), len(pg0)))
del pg0, pg1, pg2

Optimizer groups: 102 .bias, 108 conv.weight, 99 other
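del does not delete the data, only the variables (the names that point to the data). The parameter tensors stay alive because the optimizer's param_groups still reference them. Only the temporary lists pg0, pg1 and pg2 are released.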
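Here is the arithmetic above with concrete numbers. nbs=64 is the nominal batch size, so with batch_size=16 gradients are accumulated over 4 batches before each optimizer step. weight_decay is rescaled by batch_size * accumulate / nbs. The sketch reproduces that arithmetic and the naming rules that split parameters into pg0/pg1/pg2. It uses plain strings instead of real parameters, and the parameter names are illustrative.

```python
def accumulation_and_decay(batch_size: int, weight_decay: float, nbs: int = 64):
    """Reproduce the nbs/accumulate/weight_decay arithmetic from section 1.7."""
    accumulate = max(round(nbs / batch_size), 1)  # batches accumulated per optimizer step
    scaled_wd = weight_decay * batch_size * accumulate / nbs  # decay scaled to the effective batch
    return accumulate, scaled_wd


def split_param_names(names):
    """Apply the pg0/pg1/pg2 name rules to plain parameter names."""
    pg0, pg1, pg2 = [], [], []
    for k in names:
        if ".bias" in k:
            pg2.append(k)  # all biases: no weight decay
        elif ".weight" in k and ".bn" not in k:
            pg1.append(k)  # conv/linear weights: weight decay applied
        else:
            pg0.append(k)  # BatchNorm weights and everything else: no weight decay
    return pg0, pg1, pg2


names = ["model.0.conv.weight", "model.0.bn.weight", "model.0.bn.bias",
         "model.24.m.0.weight", "model.24.m.0.bias"]  # illustrative names
print(accumulation_and_decay(16, 0.0005))
print(split_param_names(names))
```

With batch_size=24, round(64 / 24) = 3, so the effective batch is 72 images and weight decay is scaled by 1.125.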

1.8 Load the pretrained model and weights, and write the training results to results.txt

# Load Model
google_utils.attempt_download(weights)
start_epoch, best_fitness = 0, 0.0
if weights.endswith('.pt'):  # pytorch format
    ckpt = torch.load(weights, map_location=device)  # load checkpoint

    # load model
    try:
        ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
                         if model.state_dict()[k].shape == v.shape}  # to FP32, filter
        model.load_state_dict(ckpt['model'], strict=False)
    except KeyError as e:
        s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
            % (opt.weights, opt.cfg, opt.weights)
        raise KeyError(s) from e

    # load optimizer
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
        best_fitness = ckpt['best_fitness']

    # load results
    if ckpt.get('training_results') is not None:
        with open(results_file, 'w') as file:
            file.write(ckpt['training_results'])  # write results.txt

    start_epoch = ckpt['epoch'] + 1
    del ckpt
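The dict comprehension in the try block keeps only the checkpoint tensors whose shape matches the new model. This is what allows COCO weights to be fine-tuned on a dataset with a different nc. Each Detect output conv has 3 * (nc + 5) channels: 3 anchors per level, each with 4 box values, 1 objectness score and nc class scores. That gives 255 channels for COCO's 80 classes but only 21 for 2 classes, so those layers are skipped and keep their fresh initialization. The sketch below uses shape tuples in place of tensors. The layer names and shapes are illustrative.

```python
def filter_compatible(ckpt_shapes: dict, model_shapes: dict) -> dict:
    """Keep checkpoint entries whose shape matches the model, like the comprehension in 1.8.

    As in the original, a checkpoint key missing from the model raises KeyError.
    """
    return {k: s for k, s in ckpt_shapes.items() if model_shapes[k] == s}


# Illustrative shapes: an 80-class checkpoint loaded into a 2-class model
ckpt_shapes = {"model.0.conv.weight": (32, 12, 3, 3),
               "model.24.m.0.weight": (255, 128, 1, 1),
               "model.24.m.0.bias": (255,)}
model_shapes = {"model.0.conv.weight": (32, 12, 3, 3),
                "model.24.m.0.weight": (21, 128, 1, 1),
                "model.24.m.0.bias": (21,)}
print(filter_compatible(ckpt_shapes, model_shapes))
```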

1.9 Add mixed precision to the training

If mixed_precision was set to False earlier (apex not installed), this step is skipped and training runs in regular FP32.

if mixed_precision:
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)

 
 

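Note that in opt_level='O1' the first character is the capital letter O, not the digit zero. O1 is apex's mixed-precision mode. Whitelisted ops such as convolutions and matrix multiplications run in FP16, while ops that need FP32 precision (for example softmax) stay in FP32, and dynamic loss scaling is used.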
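Dynamic loss scaling, which O1 uses, works roughly as follows. The loss is multiplied by a large scale before backward so that small FP16 gradients do not underflow. If any gradient overflows to inf or NaN, the step is skipped and the scale is halved. After a run of good steps, the scale is doubled again. Below is a plain-Python sketch of that idea, with a list of floats standing in for gradients. The class and method names are hypothetical. The defaults (2**16 and 2000 steps) follow torch.cuda.amp.GradScaler's documented defaults and were not read from apex.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler illustrating the idea; not the apex.amp API."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0

    def unscale_or_skip(self, scaled_grads):
        """scaled_grads were computed from loss * scale.

        Returns the unscaled gradients, or None if an overflow means the step must be skipped.
        """
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale /= 2.0  # overflow: shrink the scale and skip this update
            self.good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the scale used in backward
        self.good_steps += 1
        if self.good_steps % self.growth_interval == 0:
            self.scale *= 2.0  # no overflow for a while: try a larger scale next time
        return grads
```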

1.10 Set up the cosine scheduler (learning-rate decay)

# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
scheduler.last_epoch = start_epoch - 1  # do not move

 
 
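Written out, the lambda is lf(x) = 0.1 + 0.9 * (1 + cos(pi * x / epochs)) / 2; the ** 1.0 is a no-op. LambdaLR sets each group's learning rate to initial_lr * lf(epoch), so the learning rate follows half a cosine from lr0 down to 0.1 * lr0. Setting scheduler.last_epoch = start_epoch - 1 lets a resumed run continue on the same curve. A quick check in plain Python follows; epochs=300 and lr0=0.01 are illustrative values.

```python
import math


def make_lf(epochs: int):
    """Same lambda as in section 1.10: cosine from 1.0 down to 0.1 over `epochs`."""
    return lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.9 + 0.1


epochs, lr0 = 300, 0.01  # illustrative values
lf = make_lf(epochs)
for e in (0, 150, 299, 300):
    # LambdaLR sets each group's lr to initial_lr * lf(epoch)
    print(e, round(lr0 * lf(e), 6))
```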

1.11 Initialize distributed training

# Initialize distributed training
if device.type != 'cpu' and torch.cuda.device_count() > 1 and torch.distributed.is_available():
    dist.init_process_group(backend='nccl',  # distributed backend
                            init_method='tcp://127.0.0.1:9999',  # init method
                            world_size=1,  # number of nodes
                            rank=0)  # node rank
    model = torch.nn.parallel.DistributedDataParallel(model)

 
 

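Distributed training is enabled when all three conditions above hold: not running on the CPU, more than one CUDA device, and torch.distributed available. I trained on a single GPU, so the condition was not met and distributed training was not used.
torch.nn.parallel.DistributedDataParallel supports multi-process data parallelism on a single machine or across several machines. Each process holds its own optimizer and performs its own update step. Gradients are averaged across processes during the backward pass, so all model replicas stay in sync.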

1.12 Load the training and test sets

# Trainloader
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                        hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect)
mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (mlc, nc, opt.cfg)

# Testloader
testloader = create_dataloader(test_path, imgsz_test, batch_size, gs, opt,
                               hyp=hyp, augment=False, cache=opt.cache_images, rect=True)[0]

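The dataloader and the testloader differ as follows:

  1. testloader: no data augmentation (augment=False), and rect=True, i.e. rectangular inference. Images keep their original aspect ratio, and the short side is padded only up to a multiple of the stride rather than to a full square.
  2. dataloader: data augmentation enabled (augment=True). Rectangular training is used only if the --rect option is given (rect=opt.rect); otherwise training uses square, mosaic-augmented images.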
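To get a feel for rect=True, here is a simplified sketch of the shape a single image would get. The long side is scaled to the target size, and the short side is rounded up to a multiple of gs instead of being padded to a full square. This is a simplification under my own assumptions, and rect_shape is a hypothetical helper. YOLOv5 actually sorts images by aspect ratio and computes one shape per batch, possibly with extra padding.

```python
import math


def rect_shape(h: int, w: int, img_size: int = 640, gs: int = 32):
    """Simplified rectangular shape for one image: (height, width), both multiples of gs."""
    r = img_size / max(h, w)  # scale so the long side becomes img_size
    return math.ceil(h * r / gs) * gs, math.ceil(w * r / gs) * gs


print(rect_shape(500, 640))  # a 640x640 square would waste the extra rows
```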

1.13 Model parameters

# Model parameters
hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
model.nc = nc  # attach number of classes to model
model.hyp = hyp  # attach hyperparameters to model
model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device)  # attach class weights

 
 
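hyp['cls'] was tuned on COCO's 80 classes, so it is scaled by nc / 80. As I understand it, labels_to_class_weights produces inverse-frequency weights normalized to sum to 1. The training loop later uses them for optional image-weighted sampling. The pure-Python sketch below works under that assumption. labels_to_class_weights_sketch and scale_cls_gain are hypothetical names, and the 0.5 gain is only an example value.

```python
from collections import Counter


def labels_to_class_weights_sketch(classes, nc):
    """Inverse-frequency class weights, normalized to sum to 1.

    Classes with no labels are counted as 1 to avoid division by zero.
    """
    counts = Counter(classes)
    inv = [1.0 / (counts.get(c, 0) or 1) for c in range(nc)]
    total = sum(inv)
    return [w / total for w in inv]


def scale_cls_gain(cls_gain: float, nc: int) -> float:
    """Scale a COCO-tuned (80-class) classification loss gain to nc classes."""
    return cls_gain * nc / 80


print(labels_to_class_weights_sketch([0, 0, 0, 1], nc=3))  # rare classes get larger weights
print(scale_cls_gain(0.5, 20))
```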

1.14 Class statistics

# Class frequency
labels = np.concatenate(dataset.labels, 0)
c = torch.tensor(labels[:, 0])  # classes
# cf = torch.bincount(c.long(), minlength=nc) + 1.
# model._initialize_biases(cf.to(device))
if tb_writer:
    plot_labels(labels)
    tb_writer.add_histogram('classes', c, 0)

 
 

1.15 Check whether the anchors fit the dataset

# Check anchors
if not opt.noautoanchor:
    check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

 
 
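check_anchors compares the label widths and heights with the anchors, using hyp['anchor_t'] as a ratio threshold. The sketch below computes a "best possible recall" metric of the kind autoanchor is based on, as far as I know. A label counts as covered if some anchor's width and height are both within a factor thr of the label's. If this recall is too low, YOLOv5 re-estimates the anchors with k-means. best_possible_recall is a hypothetical name. The real function also rescales and jitters the label sizes, which this sketch omits. The anchors used below are the standard YOLOv5 P3 anchors.

```python
def best_possible_recall(label_wh, anchors, thr=4.0):
    """Fraction of labels that fit at least one anchor within ratio threshold thr."""
    covered = 0
    for lw, lh in label_wh:
        best = 0.0
        for aw, ah in anchors:
            rw, rh = lw / aw, lh / ah
            # fold each ratio into (0, 1] and take the worse of width/height
            metric = min(min(rw, 1 / rw), min(rh, 1 / rh))
            best = max(best, metric)
        covered += best > 1 / thr
    return covered / len(label_wh)


anchors = [(10, 13), (16, 30), (33, 23)]  # YOLOv5 P3/8 anchors, in pixels
print(best_possible_recall([(10, 13), (200, 200)], anchors))  # 200x200 is too large for every anchor
```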

1.16 Exponential moving average

# Exponential moving average
ema = torch_utils.ModelEMA(model)

 
 

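In deep learning, EMA (exponential moving average) is often used to average a model's parameters over the course of training. The goal is better test metrics and a more robust model. ModelEMA keeps a smoothed copy of the weights, which is updated after every optimizer step (ema.update(model) in the training loop).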
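The sketch below shows the update rule on a dict of floats: ema = d * ema + (1 - d) * current. Here d ramps up from about 0 toward 0.9999 as updates accumulate. The average therefore follows the raw weights closely early in training and smooths them strongly later. The ramp decay * (1 - exp(-updates / 2000)) is my reading of torch_utils.ModelEMA in this version and should be treated as an assumption. ParamEMA is a hypothetical name.

```python
import math


class ParamEMA:
    """Toy EMA over a dict of float parameters with a ramped decay (assumed ModelEMA-like)."""

    def __init__(self, params: dict, decay: float = 0.9999, tau: float = 2000.0):
        self.ema = dict(params)  # start from a copy of the current weights
        self.updates = 0
        # small at first, approaches `decay` as updates grow
        self.decay_fn = lambda x: decay * (1 - math.exp(-x / tau))

    def update(self, params: dict) -> None:
        self.updates += 1
        d = self.decay_fn(self.updates)
        for k, v in params.items():
            self.ema[k] = d * self.ema[k] + (1 - d) * v


ema = ParamEMA({"w": 1.0})
ema.update({"w": 2.0})  # early update: d is tiny, so the EMA jumps almost to 2.0
print(ema.ema["w"])
```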

1.17 Start training

1.17.1 Initialize training variables

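This step records the start time and the number of batches per epoch and sets the number of burn-in (warmup) iterations. It also initializes the per-class mAP and the results tuple, then prints the image sizes, the number of dataloader workers and the number of epochs.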

# Start training
t0 = time.time()  # start time
nb = len(dataloader)  # number of batches
n_burn = max(3 * nb, 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
maps = np.zeros(nc)  # mAP per class
results = (0, 0, 0, 0, 0, 0, 0)  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
print('Using %g dataloader workers' % dataloader.num_workers)
print('Starting training for %g epochs...' % epochs)
# torch.autograd.set_detect_anomaly(True)

 
 

1.17.2 The training loop

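Inside each epoch and batch, the loop does the following:

- updates the image weights (optional) and sets up the progress bar;
- applies burn-in (warmup) to accumulate, the learning rates and the momentum;
- applies multi-scale training;
- runs the forward pass, computes the loss, backpropagates and steps the optimizer;
- updates the progress bar and logs to TensorBoard;
- computes mAP, writes the results to results.txt and saves the models (best and last).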
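Before reading the loop, it helps to see what the burn-in block computes. For the first n_burn = max(3 * nb, 1000) integrated batches, accumulate, the learning rates and the momentum are linearly interpolated with np.interp from their warmup start values to their targets:

- the bias group (j == 2, i.e. pg2) starts at lr 0.1 and falls;
- the other groups start at lr 0.0 and rise;
- momentum moves from 0.9 to hyp['momentum'].

The sketch below reproduces this in plain Python. interp and warmup_values are hypothetical helpers. lr_target stands for x['initial_lr'] * lf(epoch), and 0.937 is only an example momentum.

```python
def interp(x, xp, fp):
    """Two-point linear interpolation, clamped at the ends like np.interp."""
    (x0, x1), (y0, y1) = xp, fp
    if x <= x0:
        return y0
    if x >= x1:
        return y1
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def warmup_values(ni, n_burn, nbs, batch_size, lr_target, momentum_target, is_bias_group):
    """Values the burn-in block assigns at integrated batch ni (used only while ni <= n_burn)."""
    xi = (0, n_burn)
    # optimizer.step() runs only when ni % accumulate == 0
    accumulate = max(1, round(interp(ni, xi, (1, nbs / batch_size))))
    lr = interp(ni, xi, (0.1 if is_bias_group else 0.0, lr_target))
    momentum = interp(ni, xi, (0.9, momentum_target))
    return accumulate, lr, momentum


nb = 100  # illustrative number of batches per epoch
n_burn = max(3 * nb, 1e3)  # 1000.0 here, since 3 epochs is fewer than 1000 iterations
for ni in (0, 500, 1000):
    print(ni, warmup_values(ni, n_burn, 64, 16, 0.01, 0.937, is_bias_group=True))
```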

for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
    model.train()

    # Update image weights (optional)
    if dataset.image_weights:
        w = model.class_weights.cpu().numpy() * (1 - maps) ** 2  # class weights
        image_weights = labels_to_image_weights(dataset.labels, nc=nc, class_weights=w)
        dataset.indices = random.choices(range(dataset.n), weights=image_weights, k=dataset.n)  # rand weighted idx

    # Update mosaic border
    # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
    # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

    mloss = torch.zeros(4, device=device)  # mean losses
    print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls', 'total', 'targets', 'img_size'))
    pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
    for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
        ni = i + nb * epoch  # number integrated batches (since train start)
        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

        # Burn-in
        if ni <= n_burn:
            xi = [0, n_burn]  # x interp
            # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
            accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(ni, xi, [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(ni, xi, [0.9, hyp['momentum']])

        # Multi-scale
        if opt.multi_scale:
            # randrange needs int arguments (Python 3.12+ raises TypeError for floats)
            sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + gs)) // gs * gs  # size
            sf = sz / max(imgs.shape[2:])  # scale factor
            if sf != 1:
                ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

        # Forward
        pred = model(imgs)

        # Loss
        loss, loss_items = compute_loss(pred, targets.to(device), model)
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss_items)
            return results

        # Backward
        if mixed_precision:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # Optimize
        if ni % accumulate == 0:
            optimizer.step()
            optimizer.zero_grad()
            ema.update(model)

        # Print
        mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
        # memory_reserved() replaces the deprecated memory_cached()
        mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
        s = ('%10s' * 2 + '%10.4g' * 6) % (
            '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
        pbar.set_description(s)

        # Plot
        if ni < 3:
            f = 'train_batch%g.jpg' % ni  # filename
            result = plot_images(images=imgs, targets=targets, paths=paths, fname=f)
            <span class="token keyword">if</span> tb_writer <span class="token operator">and</span> result <span class="token keyword">is</span> <span class="token operator">not</span> <span class="token boolean">None</span><span class="token punctuation">:</span>
                tb_writer<span class="token punctuation">.</span>add_image<span class="token punctuation">(</span>f<span class="token punctuation">,</span> result<span class="token punctuation">,</span> dataformats<span class="token operator">=</span><span class="token string">'HWC'</span><span class="token punctuation">,</span> global_step<span class="token operator">=</span>epoch<span class="token punctuation">)</span>
                <span class="token comment"># tb_writer.add_graph(model, imgs)  # add model to tensorboard</span>

        <span class="token comment"># end batch ------------------------------------------------------------------------------------------------</span>

    <span class="token comment"># Scheduler</span>
    scheduler<span class="token punctuation">.</span>step<span class="token punctuation">(</span><span class="token punctuation">)</span>

    <span class="token comment"># mAP</span>
    ema<span class="token punctuation">.</span>update_attr<span class="token punctuation">(</span>model<span class="token punctuation">)</span>
    final_epoch <span class="token operator">=</span> epoch <span class="token operator">+</span> <span class="token number">1</span> <span class="token operator">==</span> epochs
    <span class="token keyword">if</span> <span class="token operator">not</span> opt<span class="token punctuation">.</span>notest <span class="token operator">or</span> final_epoch<span class="token punctuation">:</span>  <span class="token comment"># Calculate mAP</span>
        results<span class="token punctuation">,</span> maps<span class="token punctuation">,</span> times <span class="token operator">=</span> test<span class="token punctuation">.</span>test<span class="token punctuation">(</span>opt<span class="token punctuation">.</span>data<span class="token punctuation">,</span>
                                         batch_size<span class="token operator">=</span>batch_size<span class="token punctuation">,</span>
                                         imgsz<span class="token operator">=</span>imgsz_test<span class="token punctuation">,</span>
                                         save_json<span class="token operator">=</span>final_epoch <span class="token operator">and</span> opt<span class="token punctuation">.</span>data<span class="token punctuation">.</span>endswith<span class="token punctuation">(</span>os<span class="token punctuation">.</span>sep <span class="token operator">+</span> <span class="token string">'coco.yaml'</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                                         model<span class="token operator">=</span>ema<span class="token punctuation">.</span>ema<span class="token punctuation">,</span>
                                         single_cls<span class="token operator">=</span>opt<span class="token punctuation">.</span>single_cls<span class="token punctuation">,</span>
                                         dataloader<span class="token operator">=</span>testloader<span class="token punctuation">)</span>

    <span class="token comment"># Write</span>
    <span class="token keyword">with</span> <span class="token builtin">open</span><span class="token punctuation">(</span>results_file<span class="token punctuation">,</span> <span class="token string">'a'</span><span class="token punctuation">)</span> <span class="token keyword">as</span> f<span class="token punctuation">:</span>
        f<span class="token punctuation">.</span>write<span class="token punctuation">(</span>s <span class="token operator">+</span> <span class="token string">'%10.4g'</span> <span class="token operator">*</span> <span class="token number">7</span> <span class="token operator">%</span> results <span class="token operator">+</span> <span class="token string">'\n'</span><span class="token punctuation">)</span>  <span class="token comment"># P, R, mAP, F1, test_losses=(GIoU, obj, cls)</span>
    <span class="token keyword">if</span> <span class="token builtin">len</span><span class="token punctuation">(</span>opt<span class="token punctuation">.</span>name<span class="token punctuation">)</span> <span class="token operator">and</span> opt<span class="token punctuation">.</span>bucket<span class="token punctuation">:</span>
        os<span class="token punctuation">.</span>system<span class="token punctuation">(</span><span class="token string">'gsutil cp results.txt gs://%s/results/results%s.txt'</span> <span class="token operator">%</span> <span class="token punctuation">(</span>opt<span class="token punctuation">.</span>bucket<span class="token punctuation">,</span> opt<span class="token punctuation">.</span>name<span class="token punctuation">)</span><span class="token punctuation">)</span>

    <span class="token comment"># Tensorboard</span>
    <span class="token keyword">if</span> tb_writer<span class="token punctuation">:</span>
        tags <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token string">'train/giou_loss'</span><span class="token punctuation">,</span> <span class="token string">'train/obj_loss'</span><span class="token punctuation">,</span> <span class="token string">'train/cls_loss'</span><span class="token punctuation">,</span>
                <span class="token string">'metrics/precision'</span><span class="token punctuation">,</span> <span class="token string">'metrics/recall'</span><span class="token punctuation">,</span> <span class="token string">'metrics/mAP_0.5'</span><span class="token punctuation">,</span> <span class="token string">'metrics/F1'</span><span class="token punctuation">,</span>
                <span class="token string">'val/giou_loss'</span><span class="token punctuation">,</span> <span class="token string">'val/obj_loss'</span><span class="token punctuation">,</span> <span class="token string">'val/cls_loss'</span><span class="token punctuation">]</span>
        <span class="token keyword">for</span> x<span class="token punctuation">,</span> tag <span class="token keyword">in</span> <span class="token builtin">zip</span><span class="token punctuation">(</span><span class="token builtin">list</span><span class="token punctuation">(</span>mloss<span class="token punctuation">[</span><span class="token punctuation">:</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">]</span><span class="token punctuation">)</span> <span class="token operator">+</span> <span class="token builtin">list</span><span class="token punctuation">(</span>results<span class="token punctuation">)</span><span class="token punctuation">,</span> tags<span class="token punctuation">)</span><span class="token punctuation">:</span>
            tb_writer<span class="token punctuation">.</span>add_scalar<span class="token punctuation">(</span>tag<span class="token punctuation">,</span> x<span class="token punctuation">,</span> epoch<span class="token punctuation">)</span>

    <span class="token comment"># Update best mAP</span>
    fi <span class="token operator">=</span> fitness<span class="token punctuation">(</span>np<span class="token punctuation">.</span>array<span class="token punctuation">(</span>results<span class="token punctuation">)</span><span class="token punctuation">.</span>reshape<span class="token punctuation">(</span><span class="token number">1</span><span class="token punctuation">,</span> <span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">)</span>  <span class="token comment"># fitness_i = weighted combination of [P, R, mAP, F1]</span>
    <span class="token keyword">if</span> fi <span class="token operator">&gt;</span> best_fitness<span class="token punctuation">:</span>
        best_fitness <span class="token operator">=</span> fi

    <span class="token comment"># Save model</span>
    save <span class="token operator">=</span> <span class="token punctuation">(</span><span class="token operator">not</span> opt<span class="token punctuation">.</span>nosave<span class="token punctuation">)</span> <span class="token operator">or</span> <span class="token punctuation">(</span>final_epoch <span class="token operator">and</span> <span class="token operator">not</span> opt<span class="token punctuation">.</span>evolve<span class="token punctuation">)</span>
    <span class="token keyword">if</span> save<span class="token punctuation">:</span>
        <span class="token keyword">with</span> <span class="token builtin">open</span><span class="token punctuation">(</span>results_file<span class="token punctuation">,</span> <span class="token string">'r'</span><span class="token punctuation">)</span> <span class="token keyword">as</span> f<span class="token punctuation">:</span>  <span class="token comment"># create checkpoint</span>
            ckpt <span class="token operator">=</span> <span class="token punctuation">{<!-- --></span><span class="token string">'epoch'</span><span class="token punctuation">:</span> epoch<span class="token punctuation">,</span>
                    <span class="token string">'best_fitness'</span><span class="token punctuation">:</span> best_fitness<span class="token punctuation">,</span>
                    <span class="token string">'training_results'</span><span class="token punctuation">:</span> f<span class="token punctuation">.</span>read<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                    <span class="token string">'model'</span><span class="token punctuation">:</span> ema<span class="token punctuation">.</span>ema<span class="token punctuation">.</span>module <span class="token keyword">if</span> <span class="token builtin">hasattr</span><span class="token punctuation">(</span>model<span class="token punctuation">,</span> <span class="token string">'module'</span><span class="token punctuation">)</span> <span class="token keyword">else</span> ema<span class="token punctuation">.</span>ema<span class="token punctuation">,</span>
                    <span class="token string">'optimizer'</span><span class="token punctuation">:</span> <span class="token boolean">None</span> <span class="token keyword">if</span> final_epoch <span class="token keyword">else</span> optimizer<span class="token punctuation">.</span>state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">}</span>

        <span class="token comment"># Save last, best and delete</span>
        torch<span class="token punctuation">.</span>save<span class="token punctuation">(</span>ckpt<span class="token punctuation">,</span> last<span class="token punctuation">)</span>
        <span class="token keyword">if</span> <span class="token punctuation">(</span>best_fitness <span class="token operator">==</span> fi<span class="token punctuation">)</span> <span class="token operator">and</span> <span class="token operator">not</span> final_epoch<span class="token punctuation">:</span>
            torch<span class="token punctuation">.</span>save<span class="token punctuation">(</span>ckpt<span class="token punctuation">,</span> best<span class="token punctuation">)</span>
        <span class="token keyword">del</span> ckpt

    <span class="token comment"># end epoch ----------------------------------------------------------------------------------------------------</span>
<span class="token comment"># end training</span>
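The "Optimize" and "Print" steps in the batch loop can be reproduced without torch. The sketch below is illustrative only: it assumes nbs = 64 (the nominal batch size from section 1.7) and that ni counts batches since training started. It shows how often optimizer.step() fires and that the mloss update is an incremental arithmetic mean.

```python
# Pure-Python sketch of the batch-loop bookkeeping (no torch required).
# nbs=64 mirrors the nominal batch size from section 1.7.


def accumulate_steps(batch_size, nbs=64):
    """Number of batches whose gradients are summed before one optimizer.step()."""
    return max(round(nbs / batch_size), 1)


def step_iterations(num_batches, batch_size, nbs=64):
    """Batch indices ni after which optimizer.step() would run."""
    accumulate = accumulate_steps(batch_size, nbs)
    return [ni for ni in range(num_batches) if ni % accumulate == 0]


def running_mean(losses):
    """Apply mloss = (mloss * i + x) / (i + 1) one batch at a time."""
    mloss = 0.0
    for i, x in enumerate(losses):
        mloss = (mloss * i + x) / (i + 1)  # mean of the first i + 1 values
    return mloss
```

With batch_size = 16, accumulate_steps(16) returns 4, so the weights are updated once per 64 images. This keeps the effective batch size close to nbs on small GPUs. Because the condition is ni % accumulate == 0, the very first step fires after a single batch, and every later step sums accumulate batches. running_mean([1.0, 2.0, 3.0]) returns 2.0, the plain average.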
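The "Update best mAP" and "Save model" steps decide which checkpoint files are written each epoch. The following hypothetical sketch replays that decision in plain Python. It assumes save is True and best_fitness starts at 0.0. The fitness weights are an assumed example, not necessarily the values used by YOLOv5's fitness(), so check them in your version.

```python
# Hypothetical sketch of the epoch-end checkpoint decisions shown above.
# FITNESS_WEIGHTS is an assumed example; check fitness() in your YOLOv5 version.
FITNESS_WEIGHTS = (0.0, 0.01, 0.99, 0.0)  # assumed weights for (P, R, mAP, F1)


def fitness(metrics, w=FITNESS_WEIGHTS):
    """Weighted combination of (P, R, mAP, F1) for one epoch."""
    return sum(wi * xi for wi, xi in zip(w, metrics))


def checkpoint_plan(per_epoch_metrics):
    """For each epoch, list the files the loop would write (assuming save is True)."""
    best_fitness = 0.0
    epochs = len(per_epoch_metrics)
    plan = []
    for epoch, metrics in enumerate(per_epoch_metrics):
        fi = fitness(metrics)
        if fi > best_fitness:
            best_fitness = fi
        final_epoch = epoch + 1 == epochs
        files = ['last.pt']  # last.pt is overwritten every epoch
        if best_fitness == fi and not final_epoch:
            files.append('best.pt')  # only when this epoch matched the best fitness
        plan.append(files)
    return plan
```

One consequence of the condition (best_fitness == fi) and not final_epoch, as written, is that a new best reached on the final epoch is stored only in last.pt, not in best.pt.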

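Image sizes 608 train, 608 test (the training and test image sizes)
Using 8 dataloader workers (the number of DataLoader worker processes that load and augment images in parallel; this is not the batch size)
Starting training for 100 epochs… (training is set to run for 100 epochs)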
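Section 1.6 calls check_img_size(x, gs) so that both image sizes are multiples of the largest model stride gs. The sketch below is a hedged re-implementation that assumes the size is rounded up to the next multiple and that gs = 32. 608 = 19 × 32, so the 608 in the log passes unchanged.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def check_img_size(img_size, s=32):
    """Sketch: return img_size rounded up to a multiple of the max stride s."""
    new_size = make_divisible(img_size, s)
    if new_size != img_size:
        print('img_size %g is not a multiple of %g, using %g instead' % (img_size, s, new_size))
    return new_size
```

For example, check_img_size(600) would return 608, because 600 / 32 = 18.75 rounds up to 19 strides.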
——————————————————————————————————————
tqdm is a fast, extensible Python progress bar. It adds a progress indicator to long-running Python loops: you only need to wrap any iterable, as in tqdm(iterator).
Reference blog: tqdm progress bar
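pbar = tqdm(enumerate(dataloader), total=nb) creates the progress bar. total=nb is the expected number of iterations, which here is the number of batches per epoch (nb = len(dataloader)), not the number of epochs.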
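The progress-bar pattern can be tried on its own. This minimal sketch assumes tqdm is installed (pip install tqdm). The small stand-in class is only a fallback so the sketch runs without tqdm and is not part of the tqdm API. run_epoch is a hypothetical helper name.

```python
# Minimal sketch of the progress-bar pattern; assumes tqdm is installed.
try:
    from tqdm import tqdm
except ImportError:
    class tqdm:  # fallback stand-in, NOT the real tqdm
        def __init__(self, iterable, total=None):
            self.iterable, self.total = iterable, total

        def __iter__(self):
            return iter(self.iterable)

        def set_description(self, desc):
            self.desc = desc


def run_epoch(dataloader, epoch, epochs):
    """Hypothetical helper: iterate one epoch with a progress bar of nb steps."""
    nb = len(dataloader)  # number of batches = iterations per epoch
    pbar = tqdm(enumerate(dataloader), total=nb)  # enumerate() has no len(), so pass total
    seen = 0
    for i, batch in pbar:
        seen += 1
        pbar.set_description('%10s' % ('%g/%g' % (epoch, epochs - 1)))
    return nb, seen
```

The total argument matters here because an enumerate object has no length, so tqdm cannot infer how many steps the bar should show.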
——————————————————————————————————————
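Results saved to results.txt (one line per epoch):
0/49 6.44G 0.09249 0.07952 0.05631 0.2283 6 608 0.1107 0.1954 0.1029 0.03088 0.07504 0.06971 0.03865
Going by the code above, each line is the progress-bar string s (8 fields) followed by the 7 values in results. The fields are: epoch/(epochs-1), GPU memory, mean train GIoU loss, obj loss, cls loss, total loss, number of targets in the last batch, and image size. They are followed by P, R, mAP@0.5, F1 and the validation losses (GIoU, obj, cls). As a check, the three training losses in the example add up to the total: 0.09249 + 0.07952 + 0.05631 ≈ 0.2283.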
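To read results.txt programmatically, each line can be split on whitespace and mapped to column names. This sketch assumes the 15-column layout implied by the code above (8 fields from s, then 7 from results). The column names are labels chosen for illustration.

```python
# Column names inferred from the training code; the names themselves are illustrative.
COLUMNS = [
    'epoch', 'gpu_mem', 'giou', 'obj', 'cls', 'total', 'targets', 'img_size',
    'P', 'R', 'mAP@0.5', 'F1', 'val_giou', 'val_obj', 'val_cls',
]


def parse_results_line(line):
    """Map one whitespace-separated results.txt line to named columns."""
    values = line.split()
    if len(values) != len(COLUMNS):
        raise ValueError('expected %d columns, got %d' % (len(COLUMNS), len(values)))
    row = dict(zip(COLUMNS, values))
    for k in COLUMNS[2:]:
        row[k] = float(row[k])  # epoch ('0/49') and gpu_mem ('6.44G') stay strings
    return row
```

Applied to the example line, this gives row['mAP@0.5'] == 0.1029 and row['img_size'] == 608.0.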

1.18 Define the output file names

    n = opt.name
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'], [flast, fbest, fresults]):
            if os.path.exists(f1):
                os.rename(f1, f2)  # rename
                ispt = f2.endswith('.pt')  # is *.pt
                strip_optimizer(f2) if ispt else None  # strip optimizer
                os.system('gsutil cp %s gs://%s/weights' % (f2, opt.bucket)) if opt.bucket and ispt else None  # upload

 
 
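The naming rule in 1.18 prefixes a non-numeric run name with an underscore and appends a numeric name directly. Here is a hypothetical helper, output_names (not part of YOLOv5), that mirrors only the name-building part and skips the renaming, optimizer stripping and upload:

```python
import os


def output_names(name, wdir='weights' + os.sep):
    """Hypothetical helper: (results, last, best) file names as built in section 1.18."""
    if not len(name):
        return 'results.txt', wdir + 'last.pt', wdir + 'best.pt'  # nothing is renamed
    n = '_' + name if not name.isnumeric() else name
    return 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
```

For example, --name exp gives results_exp.txt and weights/last_exp.pt, while --name 3 gives results3.txt and weights/last3.pt.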

1.19 Finish training and return the results

    if not opt.evolve:
        plot_results()  # save as results.png
    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
    dist.destroy_process_group() if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()
    return results

 
 

50 epochs completed in 11.954 hours.
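The closing log line counts epochs as epoch - start_epoch + 1, so a resumed run reports only the epochs trained in the current session. This small sketch reuses the same format string without the trailing newline. completion_message is a hypothetical helper name.

```python
def completion_message(epoch, start_epoch, elapsed_seconds):
    """Hypothetical helper: final log line with epochs run and wall time in hours."""
    n_epochs = epoch - start_epoch + 1  # epoch is the last 0-based epoch index
    return '%g epochs completed in %.3f hours.' % (n_epochs, elapsed_seconds / 3600)
```

For the run above (last epoch index 49, start_epoch 0), this reproduces "50 epochs completed in 11.954 hours.", which works out to roughly 14.3 minutes per epoch.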