pytorch多进程处理数据的代码模板
和单进程的主要区别:
- 增加初始化pytorch多进程的函数 “init_distributed_mode”
- dataloader的sampler需要设置为分布式sampler(torch.utils.data.DistributedSampler)
- 启动命令发生变化,需要以“python -m torch.distributed.launch --nproc_per_node=8 --master_port 11113”为开头
import argparse
import functools
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
import cv2
from PIL import Image
from torch.utils import data
from distributed import init_distributed_mode, get_rank, get_world_size,end_distributed_mode
class VAECache(torch.utils.data.Dataset):
    """Image dataset that yields preprocessed tensors ready for VAE encoding.

    Every item is one image that has been resized, cropped, optionally
    flipped horizontally, converted with ``ToTensor`` and normalized with
    ``Normalize([0.5], [0.5])``.

    Args:
        args: Parsed command-line namespace. The dataset reads
            ``args.resolution``, ``args.center_crop`` and ``args.random_flip``
            from it.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        data = []
        ...  # Template placeholder: collect the sample records (image paths etc.) into ``data``.
        self.data = data
        print(f'dataset length: {len(data)}')

        # Preprocessing pipeline, built once and reused for every sample.
        self.train_resize = transforms.Resize(
            args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
        )
        if args.center_crop:
            self.train_crop = transforms.CenterCrop(args.resolution)
        else:
            self.train_crop = transforms.RandomCrop(args.resolution)
        # p=1.0: whether to flip is decided in ``preprocess_train``; when this
        # transform is called it always flips.
        self.train_flip = transforms.RandomHorizontalFlip(p=1.0)
        self.train_transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

    def __len__(self):
        """Return the number of sample records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image of record ``idx`` and return its preprocessed tensor.

        Each record in ``self.data`` is unpacked as a 3-tuple whose first
        element is the image path; the other two elements are not used here.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        img_path, _kpt_file, _img_key = self.data[idx]

        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            raise FileNotFoundError(f'cannot read image: {img_path}')

        # OpenCV decodes to BGR; the PIL-based transforms below expect RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return self.preprocess_train(Image.fromarray(rgb))

    def preprocess_train(self, image):
        """Resize, crop, optionally flip, then tensorize and normalize ``image``.

        Settings are read from ``self.args`` (not from a module-level global),
        so the dataset also works when imported from another module or used in
        DataLoader workers that do not inherit the script's globals.

        Args:
            image: The image to preprocess (a PIL image when called from
                ``__getitem__``).

        Returns:
            The output of ``self.train_transforms`` applied to the processed image.
        """
        cfg = self.args
        image = self.train_resize(image)

        if cfg.center_crop:
            image = self.train_crop(image)
        else:
            top, left, height, width = self.train_crop.get_params(
                image, (cfg.resolution, cfg.resolution)
            )
            image = crop(image, top, left, height, width)

        # Short-circuit keeps the RNG untouched when flipping is disabled.
        if cfg.random_flip and random.random() < 0.5:
            image = self.train_flip(image)

        return self.train_transforms(image)
if __name__ == '__main__':
    # Script entry point: every launched process encodes its own shard of the
    # dataset with the SDXL VAE. Launch with torch.distributed.launch / torchrun.

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True`` because any non-empty string is truthy.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    # Defaults to the checkpoint that used to be hard-coded, so invocations
    # without --vae_path behave as before; an explicit --vae_path is now honored.
    parser.add_argument('--vae_path', type=str, default='madebyollin/sdxl-vae-fp16-fix')
    # Both are mandatory: leaving them as None makes DataLoader either disable
    # batching (batch_size=None) or fail on num_workers.
    parser.add_argument('--bs', type=int, required=True)
    parser.add_argument('--num_workers', type=int, required=True)
    parser.add_argument('--resolution', type=int, default=1024)
    parser.add_argument('--center_crop', type=_str2bool, default=False)
    parser.add_argument('--random_flip', type=_str2bool, default=True)
    # Accept both spellings: the launcher passes --local-rank or --local_rank
    # depending on the PyTorch version. Stored as args.local_rank either way.
    parser.add_argument('--local-rank', '--local_rank', dest='local_rank', type=int)
    parser.add_argument('--dist_url', default="env://")
    args = parser.parse_args()

    # Initialize the process group.
    # NOTE(review): init_distributed_mode is project-local; the code below
    # relies on it setting args.distributed, args.rank and args.device.
    init_distributed_mode(args)

    # Build the dataset.
    dataset = VAECache(args)

    # This template only supports distributed execution. Raise explicitly
    # instead of `assert False`, which is stripped under `python -O`.
    if not args.distributed:
        raise RuntimeError(
            'distributed mode is required; launch with torch.distributed.launch or torchrun'
        )

    # shuffle=False gives every rank a fixed, disjoint slice of the dataset.
    # With the default drop_last=False the sampler pads with repeated samples
    # when the dataset size is not divisible by the world size, so a few items
    # may be processed twice.
    sampler_val = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=get_world_size(),
        rank=get_rank(),
        shuffle=False,
    )

    # Build the dataloader.
    dataloader = data.DataLoader(
        dataset,
        batch_size=args.bs,
        sampler=sampler_val,
        shuffle=False,
        pin_memory=True,
        num_workers=args.num_workers,
    )

    # Build the model; it is inference-only, so no DDP wrapper is needed.
    vae = AutoencoderKL.from_pretrained(args.vae_path, torch_dtype=torch.float16)
    # Wait until every rank has finished loading the weights.
    torch.distributed.barrier()
    vae.requires_grad_(False)
    vae.to(args.device)
    print(f'vae dtype: {vae.dtype}')

    # Show a progress bar on the main process only.
    if args.rank == 0:
        dataloader = tqdm(dataloader)

    for batch in dataloader:
        images = batch.to(args.device, dtype=vae.dtype)
        with torch.no_grad():
            model_input = vae.encode(images).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor
        ...  # Template placeholder: save the latents or do other per-batch work.

    # Keep fast ranks from exiting before the others have finished.
    torch.distributed.barrier()
# 启动指令,复制到命令行执行
'''
python -m torch.distributed.launch --nproc_per_node=8 sdxl_vae_16bit_cache.py \
--bs 8 \
--num_workers 10
'''
执行
python -m torch.distributed.launch --nproc_per_node=8 sdxl_vae_16bit_cache.py --bs 8 --num_workers 10