Pytorch分布式训练与断点续训

本文介绍了PyTorch的分布式训练,包括Master Node和Slave Node的概念,以及如何配置和启动分布式训练。内容涵盖world_size、Local Rank和Global Rank的解释。此外,还讨论了断点续训的实现,强调保存模型、优化器和训练状态的重要性。最后,提供了一个分布式训练和断点续训的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1. Pytorch分布式训练

Pytorch支持多机多卡分布式训练,参与分布式训练的机器用Node表述(Node不限定是物理机器,还是容器,例如docker,一个Node节点就是一台机器),Node又分为Master Node、Slave Node,Master Node只有一个,Slave Node可以有多个,假定现在有两台机器参与分布式训练,每台机器有4张显卡,分别在两台机器上执行如下命令(以yolov5训练为例):

Master Node执行如下命令:

python -m torch.distributed.launch \
       --nnodes 2 \
       --nproc_per_node 4 \
       --use_env \
       --node_rank 0 \
       --master_addr "192.168.1.2" \
       --master_port 1234 \
       train.py \
       --batch 64 \
       --data coco.yaml \
       --cfg yolov5s.yaml \
       --weights 'yolov5s.pt'

 Slave Node执行如下命令:

python -m torch.distributed.launch \
       --nnodes 2 \
       --nproc_per_node 4 \        
       --use_env \
       --node_rank 1 \
       --master_addr "192.168.1.2" \
       --master_port 1234 train.py \
       --batch 64 \
       --data coco.yaml \
       --cfg yolov5s.yaml \
       --weights 'yolov5s.pt

上述的命令中:

--nnodes:表示一共多少台机器参与分布式训练,也就是有几个Node,只有2台机器所以设置为2
--nproc_per_node:表示每台机器有多少张显卡,每台机器有4张显卡所以设置为4
--node_rank:表示当前机器的序号,一般设置为0的作为Master Node
--master_add:表示Master Node的ip地址
--master_add:表示Master Node的端口号


一般情况下,当两台机器都运行完命令,训练就开始了,否则master处于等待状态,直到slave节点也就绪,分布式训练才会开始。

如下所示,训练开始后Node节点的分配情况:

Node 0
    Process0 [Global Rank=0, Local Rank=0] -> GPU 0
    Process1 [Global Rank=1, Local Rank=1] -> GPU 1
    Process2 [Global Rank=2, Local Rank=2] -> GPU 2
    Process3 [Global Rank=3, Local Rank=3] -> GPU 3
Node 1
    Process4 [Global Rank=4, Local Rank=0] -> GPU 0
    Process5 [Global Rank=5, Local Rank=1] -> GPU 1
    Process6 [Global Rank=6, Local Rank=2] -> GPU 2
    Process7 [Global Rank=7, Local Rank=3] -> GPU 3

每一张显卡被分配一个进程,从Process0 ~ Process7,Global Rank表示在整个分布式训练任务中的分布式进程编号,从Global Rank = 0 ~ Global Rank = 7

Local Rank表示在某个Node内的编号,在Node 0中Local Rank = 0 ~ Local Rank = 3。另外,如果只写rank前面没有global、local等字段一般指代Global Rank。

world_size表示全局进程数量,也就是分布式进程的数量,在上述的配置中world_size = 8,如果一共有3个node(nnodes=3),每个node包含8个GPU,设置nproc_per_node=4,world_size就是3 * 4 = 12,为什么不是3 * 8 = 24呢?因为每个node虽然有8个GPU,但是命令设置只使用其中4个(nproc_per_node=4),有而不使用是不算数的。
 

训练算法的时候如果要使用分布式训练,需要对训练流程添加分布式的支持,主要有如下步骤:

1) 初始化分布式进程环境;

2)对数据集构建分布式采样器;

3)对网络模型用DistributedDataParallel进行包装;

4)日志与模型保存在主进程中进行;

5)对loss、评估指标等数据进行all_reduce同步;

6)如果网络中存在BN层,可开启BN同步

7)如果多机进行分布式训练,需要保证Node直接网络互通。

更多分布式训练的知识可参考:(471条消息) Pytorch中多GPU并行计算教程_太阳花的小绿豆的博客-CSDN博客_pytorch 多gpu

2. 断点续训

断点续训比较简单,在训练的过程中需要在checkpoints中保存能够回复模型训练的数据,主要包括:模型权重参数、优化器、学习率调度器、当前训练的轮次(比如epoch数),另外,也可以保存评估指标、训练参数等数据。在恢复训练的时候将上述数据从checkpoint中取出,从当前状态继续训练。

3. 分布式训练与断点续训代码示例

dist_classification_train.py

import argparse
import os
import numpy as np
import random
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torchvision import models
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.utils.tensorboard impor
mmdetection是一个基于PyTorch的开源目标检测框架,支持多种经典模型和数据集。下面是mmdetection进行断点续训的一般步骤: 1. 在训练过程中,设置合适的`checkpoint_config`参数,以便定期保存模型参数和优化器状态。其中包括: - `interval`:保存checkpoint的间隔epoch数; - `save_optimizer`:是否同时保存优化器状态; - `type`:checkpoint保存的格式,如`epoch`、`iteration`等; - `max_to_keep`:保存的最大checkpoint数目。 例如,在训练Faster R-CNN模型时,可以设置如下的checkpoint_config: ```python checkpoint_config = dict(interval=1, save_optimizer=True, type='epoch', max_to_keep=5) ``` 2. 当训练过程中断时,需要手动加载之前保存的checkpoint。可以使用`mmdet.core.checkpoint.load_checkpoint()`函数加载checkpoint,然后将返回的checkpoint字典中的参数和优化器状态加载到当前模型中。 例如,假设我们之前保存了一个名为`epoch_10.pth`的checkpoint,可以使用以下代码加载checkpoint: ```python from mmdet.core import checkpoint checkpoint_file = 'epoch_10.pth' checkpoint_dict = checkpoint.load_checkpoint(model, checkpoint_file, map_location='cpu') optimizer.load_state_dict(checkpoint_dict['optimizer']) start_epoch = checkpoint_dict['epoch'] + 1 ``` 其中,`model`是当前的模型,`optimizer`是当前的优化器,`map_location`是指定保存checkpoint时使用的设备,如`'cpu'`或`'cuda:0'`等。 3. 基于加载的checkpoint,从上次训练结束的epoch开始,继续训练模型。需要注意的是,由于在加载checkpoint时已经将当前模型的参数和优化器状态设置为上次训练结束的状态,因此可以直接调用训练函数进行训练,无需重新初始化模型和优化器。 例如,在使用`train_detector()`函数进行训练时,可以设置`start_epoch`参数为上次训练结束的epoch,然后继续训练模型: ```python from mmdet.apis import train_detector # continue training from last epoch epochs = 10 train_detector(model, dataset, cfg, distributed=False, validate=True, start_epoch=start_epoch, epochs=epochs) ``` 需要注意的是,由于mmdetection支持分布式训练,因此在进行断点续训时需要根据当前的训练方式(单机多卡、多机多卡等)和之前保存checkpoint时的设置进行相应的调整。同时,在分布式训练中,需要保证所有节点使用相同的checkpoint进行训练,以避免模型不同步的问题。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

洪流之源

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值