1. Pytorch分布式训练
Pytorch支持多机多卡分布式训练,参与分布式训练的机器用Node表述(Node不限定是物理机器,还是容器,例如docker,一个Node节点就是一台机器),Node又分为Master Node、Slave Node,Master Node只有一个,Slave Node可以有多个,假定现在有两台机器参与分布式训练,每台机器有4张显卡,分别在两台机器上执行如下命令(以yolov5训练为例):
Master Node执行如下命令:
python -m torch.distributed.launch \
--nnodes 2 \
--nproc_per_node 4 \
--use_env \
--node_rank 0 \
--master_addr "192.168.1.2" \
--master_port 1234 \
train.py \
--batch 64 \
--data coco.yaml \
--cfg yolov5s.yaml \
--weights 'yolov5s.pt'
Slave Node执行如下命令:
python -m torch.distributed.launch \
--nnodes 2 \
--nproc_per_node 4 \
--use_env \
--node_rank 1 \
--master_addr "192.168.1.2" \
--master_port 1234 train.py \
--batch 64 \
--data coco.yaml \
--cfg yolov5s.yaml \
--weights 'yolov5s.pt
上述的命令中:
--nnodes:表示一共多少台机器参与分布式训练,也就是有几个Node,只有2台机器所以设置为2
--nproc_per_node:表示每台机器有多少张显卡,每台机器有4张显卡所以设置为4
--node_rank:表示当前机器的序号,一般设置为0的作为Master Node
--master_add:表示Master Node的ip地址
--master_add:表示Master Node的端口号
一般情况下,当两台机器都运行完命令,训练就开始了,否则master处于等待状态,直到slave节点也就绪,分布式训练才会开始。
如下所示,训练开始后Node节点的分配情况:
Node 0
Process0 [Global Rank=0, Local Rank=0] -> GPU 0
Process1 [Global Rank=1, Local Rank=1] -> GPU 1
Process2 [Global Rank=2, Local Rank=2] -> GPU 2
Process3 [Global Rank=3, Local Rank=3] -> GPU 3
Node 1
Process4 [Global Rank=4, Local Rank=0] -> GPU 0
Process5 [Global Rank=5, Local Rank=1] -> GPU 1
Process6 [Global Rank=6, Local Rank=2] -> GPU 2
Process7 [Global Rank=7, Local Rank=3] -> GPU 3
每一张显卡被分配一个进程,从Process0 ~ Process7,Global Rank表示在整个分布式训练任务中的分布式进程编号,从Global Rank = 0 ~ Global Rank = 7
Local Rank表示在某个Node内的编号,在Node 0中Local Rank = 0 ~ Local Rank = 3。另外,如果只写rank前面没有global、local等字段一般指代Global Rank。
world_size表示全局进程数量,也就是分布式进程的数量,在上述的配置中world_size = 8,如果一共有3个node(nnodes=3),每个node包含8个GPU,设置nproc_per_node=4,world_size就是3 * 4 = 12,为什么不是3 * 8 = 24呢?因为每个node虽然有8个GPU,但是命令设置只使用其中4个(nproc_per_node=4),有而不使用是不算数的。
训练算法的时候如果要使用分布式训练,需要对训练流程添加分布式的支持,主要有如下步骤:
1) 初始化分布式进程环境;
2)对数据集构建分布式采样器;
3)对网络模型用DistributedDataParallel进行包装;
4)日志与模型保存在主进程中进行;
5)对loss、评估指标等数据进行all_reduce同步;
6)如果网络中存在BN层,可开启BN同步
7)如果多机进行分布式训练,需要保证Node直接网络互通。
更多分布式训练的知识可参考:(471条消息) Pytorch中多GPU并行计算教程_太阳花的小绿豆的博客-CSDN博客_pytorch 多gpu
2. 断点续训
断点续训比较简单,在训练的过程中需要在checkpoints中保存能够回复模型训练的数据,主要包括:模型权重参数、优化器、学习率调度器、当前训练的轮次(比如epoch数),另外,也可以保存评估指标、训练参数等数据。在恢复训练的时候将上述数据从checkpoint中取出,从当前状态继续训练。
3. 分布式训练与断点续训代码示例
dist_classification_train.py
import argparse
import os
import numpy as np
import random
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torchvision import models
import torchvision.transforms as transforms
import torch.distributed as dist
from torch.utils.tensorboard impor