由于transformer模型参数量巨大,数据集也巨大,所以对显卡需求越来越大,单卡训练非常的耗费时间。实验室还有不少显卡但是不会用多卡就很糟心,所以得把多卡用上。多卡用到的库有不少,最受欢迎的应该是DP和DDP,但是DP只能解决显存不足的问题,并不能减少时间,所以DDP采用的更多。说到单机多卡,网上的教程倒是不少,原理解析的也挺明白,所以废话留在后头,直接来一个DDP的单机多卡通用模板。在自己测试过后,单卡一个epoch为8小时=4卡2小时,还是非常的方便。
# 最重要的模块argparse, distributed这俩是多卡训练的必要
import argparse
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch
import torch.distributed as dist
# 不用管为啥直接抄上这5行
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int)
args = parser.parse_args()
dist.init_process_group(backend='nccl', world_size="有几张卡填几", rank=args.local_rank)
torch.cuda.set_device(args.local_rank)
# 正常把模型加载进来之后, 直接把第二行抄上, 抄完第二行之后,下面但凡遇到model. 都要改成model.module.
model = ......
model = torch.nn.parallel.DistributedDataParallel(model.cuda(args.local_rank), device_ids=[args.local_rank])
#构建数据集之后不要直接dataloader, 中间加一行sampler再dataloader
train_dataset = ......
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=..., sampler=train_sampler)
val_dataset = ......
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
val_dataloader = DataLoader(val_dataset, batch_size=..., sampler=val_sampler)
# 训练代码中大概这种格式
for epoch in range(epochs):
......
# 每张卡取到的数据应该是总数据集的 1/n,要是除不开自动补全, 这行得到了每张卡的数据个数,其实没有必要写这一行,对整体影响几乎为0
each_dist_train_data_num = ((len(train_dataset) % dist.get_world_size()) + len(train_dataset)) / dist.get_world_size()
# 开始取数据之前把这行加上
train_sampler.set_epoch(epoch)
for i, data in enumerate(train_dataloader):
# 数据加载进cuda和单机代码不一样
data = data.cuda(args.local_rank)
outputs = model(data)
optimizer.zero_grad()
loss = ......
loss.backward()
optimizer.step()
# 折腾完把loss和acc记录下来
train_loss =......
train_acc = .........
# 预测和训练一样的
with torch.no_grad():
model.eval()
each_dist_val_data_num = ((len(val_dataset) % dist.get_world_size()) + len(val_dataset)) / dist.get_world_size()
val_sampler.set_epoch(epoch)
for i, (text, image, label) in enumerate(val_dataloader):
data = data.cuda(args.local_rank)
outputs = model(data)
optimizer.zero_grad()
loss = ......
loss.backward()
optimizer.step()
# 折腾完把loss和acc记录下来
val_loss =......
val_acc = .........
# 开始算平均损失和准确率
# 1. 每张卡的数据要汇总,所以多了一步汇总操作,该操作要求输入数据为tensor
train_loss = torch.tensor(train_loss, dtype=torch.float).cuda(args.local_rank)
# 2. 损失值应当为每张卡损失之和的平均
"""
先算单卡的平均损失:单卡总损失 / 单卡计算的数据集
然后把多卡的损失值加在一起:dist.all_reduce(单卡平均损失)
最后算多卡平均损失:多卡总损失 / 卡数
"""
avg_train_loss = train_loss / each_dist_train_data_num
dist.all_reduce(avg_train_loss)
avg_train_loss = avg_train_loss / dist.get_world_size()
# 3. 准确率就是把所有卡正确的个数加在一起 / 总数据个数, 这个方法更简便(不完全准确但可忽略)
train_acc = torch.tensor(train_acc, dtype=torch.float).cuda(args.local_rank)
dist.all_reduce(train_acc)
avg_train_acc = train_acc / len(train_dataset) * 100
val_loss = torch.tensor(val_loss, dtype=torch.float).cuda(args.local_rank)
avg_valid_loss = val_loss / each_dist_val_data_num
dist.all_reduce(avg_valid_loss)
avg_valid_loss = avg_valid_loss / dist.get_world_size()
val_acc = torch.tensor(val_acc, dtype=torch.float).cuda(args.local_rank)
dist.all_reduce(val_acc)
avg_valid_acc = val_acc / len(val_dataset) * 100
# 多卡训练是多进程并行训练, 模型在0卡存一个进行了,包括打印和日志都需要加上if dist.get_rank() == 0: 保证只打印一遍
if dist.get_rank() == 0:
model_state_dict = model.module.state_dict()
torch.save(model_state_dict, "model.pth")
这是大概的模板,运行需要终端运行,输入:nproc_per_node有几张卡输入几
python -m torch.distributed.launch --nproc_per_node=int xxx.py
下面是注意事项:
1. model = torch.nn.parallel.DistributedDataParallel()代码之后,所有的model类方法都会变成model.module的类方法。否则会报错:model没有xxx方法。
2. data.cuda()的时候要添加参数data.cuda(args.local_rank),保证不同数据分配到不同的gpu上。
3. 数据集DistributedSampler方法代替了shuffle=True,在训练载入数据前要添加train_sampler.set_epoch(epoch)防止loss均匀下降。
4. 多gpu虽然是多进程,但是会共享同一份数据集,batch_size设置还是为单卡的最大batch_size,不是有几张卡就成多少倍。
5. 多gpu共同训练一份数据集,如果数据集长度不能被gpu个数整除,会自动补上数据计算,实际训练数据集个数应为数据集长度+ 不能整除的个数(单机最多7个),所以对大规模数据集来讲基本可以忽略不计,算loss和acc也可以直接忽略。
6. 损失值应当为每张卡损失之和的平均值,好好算算这个逻辑,最后打印出来的loss和acc应该与正常单卡训练相差不大,而不是成卡的数量倍数变化的。
7. 由于是多进程运行,一行打印代码可能会打印好几次,所以只在主进程中打印,加一句判断条件if dist.get_rank() == 0: 再进行日志输出和模型保存。