pytorch多进程处理数据的代码模板

pytorch多进程处理数据的代码模板

和单进程版本相比,主要区别如下:

  1. 增加初始化 PyTorch 多进程(分布式)环境的函数 “init_distributed_mode”;
  2. DataLoader 的 sampler 需要设置为分布式 sampler(torch.utils.data.DistributedSampler);
  3. 启动命令发生变化,需要以 “python -m torch.distributed.launch --nproc_per_node=8 --master_port 11113” 开头(--master_port 为可选参数,可按需修改以避免端口冲突)。
import argparse
import functools
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
import cv2
from PIL import Image
from torch.utils import data
from distributed import init_distributed_mode,  get_rank, get_world_size,end_distributed_mode



class VAECache(torch.utils.data.Dataset):
    """Image dataset that yields preprocessed tensors for VAE latent caching.

    Each item is read from disk with OpenCV, converted BGR -> RGB, resized,
    cropped (center or random), optionally flipped horizontally, and finally
    converted to a tensor normalized with mean 0.5 / std 0.5.

    Args:
        args: Parsed command-line namespace. Reads ``resolution``,
            ``center_crop`` and ``random_flip``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        samples = []
        ...  # TODO: collect the sample list (image paths etc.) into ``samples``.
        # Each entry is unpacked in __getitem__ as (img_path, kpt_file, img_key).
        self.data = samples
        print(f'dataset length: {len(samples)}')

        # Preprocessing pipeline, built once and reused for every sample.
        # Resize with an int target scales the shorter side to ``resolution``.
        self.train_resize = transforms.Resize(
            args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
        )
        self.train_crop = (
            transforms.CenterCrop(args.resolution)
            if args.center_crop
            else transforms.RandomCrop(args.resolution)
        )
        # p=1.0: the flip always happens when called; the coin toss is done in
        # preprocess_train so that it can be gated on ``args.random_flip``.
        self.train_flip = transforms.RandomHorizontalFlip(p=1.0)
        self.train_transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return its preprocessed tensor.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Only the path is needed here; the 3-way unpacking is kept so that
        # malformed entries in ``self.data`` still fail loudly.
        img_path, _kpt_file, _img_key = self.data[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'failed to read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.preprocess_train(Image.fromarray(img))

    def preprocess_train(self, image):
        """Resize, crop, optionally flip, and tensorize a single image.

        Uses ``self.args`` (not a module-level ``args``) so the dataset also
        works when imported elsewhere or pickled into DataLoader workers.
        """
        args = self.args
        image = self.train_resize(image)
        if args.center_crop:
            image = self.train_crop(image)
        else:
            top, left, height, width = self.train_crop.get_params(
                image, (args.resolution, args.resolution)
            )
            image = crop(image, top, left, height, width)
        if args.random_flip and random.random() < 0.5:
            image = self.train_flip(image)
        # NOTE: crop coordinates / original size are not returned; recompute
        # them here if SDXL size/crop conditioning is ever needed downstream.
        return self.train_transforms(image)



def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True; this converter is used instead so that e.g.
    ``--center_crop False`` behaves as written.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Optional local/hub path of the VAE; falls back to the fp16-fixed SDXL VAE.
    # NOTE(review): weights are loaded in float16, which is only known to be
    # safe for the fp16-fix VAE — verify when pointing at another checkpoint.
    parser.add_argument('--vae_path', type=str, default=None)
    parser.add_argument('--bs', type=int, required=True)
    parser.add_argument('--num_workers', type=int, required=True)
    parser.add_argument('--resolution', type=int, default=1024)
    parser.add_argument('--center_crop', type=str2bool, default=False)
    parser.add_argument('--random_flip', type=str2bool, default=True)
    # Older launchers pass --local_rank, newer ones --local-rank; accept both.
    parser.add_argument('--local-rank', '--local_rank', type=int)
    parser.add_argument('--dist_url', default="env://")

    args = parser.parse_args()

    # Initialise multi-process mode. NOTE(review): the code below reads
    # args.distributed, args.rank and args.device, presumably set by
    # init_distributed_mode — confirm in distributed.py.
    init_distributed_mode(args)
    if not args.distributed:
        # This script deliberately supports distributed launches only.
        raise RuntimeError(
            'distributed mode is required; launch with torch.distributed.launch'
        )

    dataset = VAECache(args)

    # Each rank gets its own shard; no shuffling keeps the ordering stable.
    # With the default drop_last=False the sampler pads by repeating samples
    # so all ranks see the same count, i.e. a few images may be encoded twice.
    sampler_val = torch.utils.data.DistributedSampler(
        dataset, num_replicas=get_world_size(), rank=get_rank(), shuffle=False)
    dataloader = data.DataLoader(dataset, batch_size=args.bs,
                                 sampler=sampler_val,
                                 shuffle=False,
                                 pin_memory=True,
                                 num_workers=args.num_workers)

    # Inference only: every rank holds a full copy, no DDP wrapper needed.
    vae = AutoencoderKL.from_pretrained(
        args.vae_path or "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    torch.distributed.barrier()
    vae.requires_grad_(False)
    vae.to(args.device)
    print(f'vae dtype: {vae.dtype}')

    # Only the main process shows a progress bar.
    if args.rank == 0:
        dataloader = tqdm(dataloader)
    for batch in dataloader:
        images = batch.to(args.device, dtype=vae.dtype)
        with torch.no_grad():
            model_input = vae.encode(images).latent_dist.sample()
        model_input = model_input * vae.config.scaling_factor
        ...  # TODO: save the latents or do further processing here.

    torch.distributed.barrier()
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


# Launch command; copy it to a shell to run.
# NOTE: newer PyTorch releases deprecate torch.distributed.launch in favour of
# torchrun; check that init_distributed_mode reads LOCAL_RANK from the
# environment before switching.
'''
python -m torch.distributed.launch --nproc_per_node=8 sdxl_vae_16bit_cache.py \
--bs 8 \
--num_workers 10
'''

执行

python -m torch.distributed.launch --nproc_per_node=8 sdxl_vae_16bit_cache.py --bs 8 --num_workers 10

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值