DistributedDataParallel (DDP) is multi-process: it supports data parallelism (and composes with model parallelism), whether on a single machine with multiple GPUs or across multiple machines. During the backward pass, processes exchange only gradients via all-reduce, rather than scattering inputs and gathering outputs through a single master process the way DataParallel does, so it runs more efficiently than DataParallel.
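The core pattern is small: spawn one process per GPU, call dist.init_process_group in each worker, move the model to that worker's device, and wrap it in DDP. A minimal sketch, assuming a single node; the tiny nn.Linear model and the random input here are placeholders, not part of the example below:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    # With MASTER_ADDR/MASTER_PORT set, init_process_group uses the default env:// rendezvous
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '5678'
    # 'nccl' requires CUDA; use 'gloo' to try this on CPU
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    model = nn.Linear(10, 2).to(rank)          # placeholder model, one GPU per rank
    ddp_model = DDP(model, device_ids=[rank])
    loss = ddp_model(torch.randn(8, 10).to(rank)).sum()
    loss.backward()                            # triggers the gradient all-reduce
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

The same worker structure carries over to the full example; only the model, data pipeline, and training loop change.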
Below is a complete example of fine-tuning BERT for text classification.
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import Dataset
from torch.utils.data import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import BertTokenizer, BertForSequenceClassification
# Rendezvous address and port used by init_process_group (env:// default)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '5678'
# Load the data and build a Hugging Face Dataset
def create_dataset(data_file, tokenizer):
    print('create dataset')
    # Each line of the file is expected to be: <text>\t<label>
    with open(data_file, 'r', encoding='utf-8') as f:
        data = [_.strip().split('\t') for _ in f.readlines()]
    x = [_[0] for _ in data]
    y = [int(_[1]) for _ in data]
    data_dict = {
        'text': x, 'label': y