Brushnet复现过程4-15

使用环境pytorch2023

项目地址:TencentARC/BrushNet: [ECCV 2024] The official implementation of paper "BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion" (github.com)

下载的数据集:项目中提到的BrushData的文件夹中的00200.tar,BeushDench一整个文件夹,ckpt文件夹中的segmentation_mask_brushnet_ckpt

需要将下载的文件整理成如下结构:

|-- data
    |-- BrushData
    |-- BrushDench
    |-- EditBench
    |-- ckpt
        |-- realisticVisionV60B1_v51VAE
            |-- model_index.json
            |-- vae
            |-- ...
        |-- segmentation_mask_brushnet_ckpt
        |-- segmentation_mask_brushnet_ckpt_sdxl_v0
        |-- random_mask_brushnet_ckpt
        |-- random_mask_brushnet_ckpt_sdxl_v0
        |-- ...

训练命令:

accelerate launch examples/brushnet/train_brushnet.py --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 --output_dir runs/logs/brushnet_segmentationmask --train_data_dir data/BrushData --resolution 512 --learning_rate 1e-5 --train_batch_size 2 --tracker_project_name brushnet --report_to tensorboard --resume_from_checkpoint latest --validation_steps 300 --checkpointing_steps 10000

运行指令后又出现了连接不上huggingface的情况,在examples/brushnet/train_brushnet.py文件开头加上如下两行代码成功解决

import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

训练过程如下

dabug1

File "D:\桌面\BrushNet-main\src\diffusers\utils\dynamic_modules_utils.py", line 28, in <module> from huggingface_hub import cached_download, hf_hub_download, model_info ImportError: cannot import name 'cached_download' from 'huggingface_hub'

原因:huggingface_hub 版本不兼容​

解决方法:更新代码以适配新版库

# 原代码
from huggingface_hub import cached_download

# 修改为
from huggingface_hub import hf_hub_download as cached_download

debug2

File "D:\桌面\BrushNet-main\src\diffusers\utils\import_utils.py", line 720, in _get_module raise RuntimeError( RuntimeError: Failed to import diffusers.models.autoencoders.autoencoder_kl because of the following error (look up to see its traceback): Failed to import transformers.models.clip.modeling_clip because of the following error (look up to see its traceback): No module named 'torch.distributed.tensor'

原因:PyTorch的分布式张量模块缺失导致依赖链断裂

解决方法:强制升级PyTorch

pip install torch==2.1.2+cu121 --extra-index-url https://download.pytorch.org/whl/nightly/cu121

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值