Deploying IMAGDressing: A New Virtual Try-On Framework

IMAGDressing is a new virtual try-on framework developed jointly by Nanjing University of Science and Technology, Wuhan University of Technology, Tencent AI Lab, and Nanjing University.

The project aims to improve the online shopping experience with virtual try-on (VTON) technology that renders garments on a person realistically.

IMAGDressing defines a new virtual dressing task focused on generating freely editable human images with a fixed garment and optional conditions, and designs a comprehensive affinity metric to evaluate the consistency between the generated images and the reference garments.

In addition, IMAGDressing-v1 incorporates a garment UNet that captures semantic features from CLIP and texture features from the VAE, and introduces a hybrid attention module, consisting of frozen self-attention and trainable cross-attention, to integrate the garment features into a frozen denoising UNet, so users can still control different scenes through text prompts.
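As an illustration only (a minimal sketch, not the repository's actual classes; all names here are made up), the hybrid attention idea can be pictured as a frozen self-attention branch plus a trainable cross-attention branch that injects the garment features produced by the garment UNet:

import torch
import torch.nn as nn

class HybridAttentionSketch(nn.Module):
    def __init__(self, hidden_size=768, garment_dim=768, num_heads=8):
        super().__init__()
        # Frozen self-attention, standing in for the pretrained denoising UNet's attn1
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        for p in self.self_attn.parameters():
            p.requires_grad = False
        # Trainable cross-attention over garment features from the garment UNet
        self.cross_attn = nn.MultiheadAttention(hidden_size, num_heads,
                                                kdim=garment_dim, vdim=garment_dim,
                                                batch_first=True)
        self.garment_scale = 1.0  # similar in spirit to image_scale in the pipeline call

    def forward(self, hidden_states, garment_features):
        self_out, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        garment_out, _ = self.cross_attn(hidden_states, garment_features, garment_features)
        return self_out + self.garment_scale * garment_out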

GitHub repository: https://github.com/muzishen/IMAGDressing

I. Environment Setup

1. Python environment

Python 3.10 or newer is recommended.
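For example, a dedicated conda environment can be created first (the environment name below is arbitrary):

conda create -n imagdressing python=3.10 -y

conda activate imagdressing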

2. Installing pip packages

pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
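After installation, a quick check confirms that the CUDA 11.8 build of PyTorch is active:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"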

3. Downloading the IMAGDressing model

git lfs install

git clone https://huggingface.co/feishen29/IMAGDressing

4. Downloading the sd-vae-ft-mse model

git lfs install

git clone https://huggingface.co/stabilityai/sd-vae-ft-mse

5. Downloading the Realistic_Vision_V4.0_noVAE model

git lfs install

git clone https://huggingface.co/SG161222/Realistic_Vision_V4.0_noVAE

6. Downloading the IP-Adapter-FaceID model

git lfs install

git clone https://huggingface.co/h94/IP-Adapter-FaceID

7. Downloading the control_v11p_sd15_openpose model

git lfs install

git clone https://huggingface.co/lllyasviel/control_v11p_sd15_openpose

8. Downloading the IP-Adapter model

git lfs install

git clone https://huggingface.co/h94/IP-Adapter

9. Downloading the IDM-VTON model

git lfs install

git clone https://huggingface.co/spaces/yisol/IDM-VTON
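Alternatively, the same repositories can be fetched with huggingface_hub instead of git lfs (a sketch; the local target directories are arbitrary):

from huggingface_hub import snapshot_download

snapshot_download(repo_id="feishen29/IMAGDressing", local_dir="ckpt/IMAGDressing")
snapshot_download(repo_id="stabilityai/sd-vae-ft-mse", local_dir="ckpt/sd-vae-ft-mse")
snapshot_download(repo_id="SG161222/Realistic_Vision_V4.0_noVAE", local_dir="ckpt/Realistic_Vision_V4.0_noVAE")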

II. Functional Testing

1. Command-line test

(1) Python test with a specified garment

import os
import torch
from PIL import Image
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from torchvision import transforms
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from adapter.attention_processor import CacheAttnProcessor2_0, RefSAttnProcessor2_0, CAttnProcessor2_0
import argparse
from adapter.resampler import Resampler
from dressing_sd.pipelines.IMAGDressing_v1_pipeline import IMAGDressing_v1

def resize_img(input_image, max_side=640, min_side=512, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
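    # Scale the image to fit within min_side/max_side, then round both dimensions
    # down to multiples of base_pixel_number (64) so they divide evenly through
    # the UNet's downsampling stages.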
    w, h = input_image.size
    ratio = min_side / min(h, w)
    w, h = round(ratio * w), round(ratio * h)
    ratio = max_side / max(h, w)
    input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
    w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
    h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)
    return input_image

def image_grid(imgs, rows, cols):
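    # Paste a list of images into a single rows x cols grid image.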
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def prepare(args):
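    # Build the IMAGDressing-v1 pipeline: load the VAE, tokenizer, text/image
    # encoders, the denoising UNet and a reference (garment) UNet, install the
    # custom attention processors, and restore the trained weights from the checkpoint.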
    generator = torch.Generator(device=args.device).manual_seed(42)
    vae = AutoencoderKL.from_pretrained("path/to/sd-vae-ft-mse").to(dtype=torch.float16, device=args.device)
    tokenizer = CLIPTokenizer.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="text_encoder").to(dtype=torch.float16, device=args.device)
    image_encoder = CLIPVisionModelWithProjection.from_pretrained("path/to/IP-Adapter", subfolder="models/image_encoder").to(dtype=torch.float16, device=args.device)
    unet = UNet2DConditionModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="unet").to(dtype=torch.float16, device=args.device)

    image_proj = Resampler(dim=unet.config.cross_attention_dim, depth=4, dim_head=64, heads=12, num_queries=16, embedding_dim=image_encoder.config.hidden_size, output_dim=unet.config.cross_attention_dim, ff_mult=4).to(dtype=torch.float16, device=args.device)

    attn_procs = {}
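    # Replace the denoising UNet's attention processors: self-attention layers (attn1)
    # get RefSAttnProcessor2_0, which consumes the garment features cached by the
    # reference UNet; cross-attention layers get CAttnProcessor2_0 for the conditioning inputs.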
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = RefSAttnProcessor2_0(name, hidden_size)
        else:
            attn_procs[name] = CAttnProcessor2_0(name, hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()).to(dtype=torch.float16, device=args.device)

    ref_unet = UNet2DConditionModel.from_pretrained("path/to/Realistic_Vision_V4.0_noVAE", subfolder="unet").to(dtype=torch.float16, device=args.device)
    ref_unet.set_attn_processor({name: CacheAttnProcessor2_0() for name in ref_unet.attn_processors.keys()})

    model_sd = torch.load(args.model_ckpt, map_location="cpu")["module"]
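    # Split the combined checkpoint into per-module state dicts by key prefix.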

    ref_unet_dict = {}
    unet_dict = {}
    image_proj_dict = {}
    adapter_modules_dict = {}
    for k, v in model_sd.items():
        if k.startswith("ref_unet"):
            ref_unet_dict[k.replace("ref_unet.", "")] = v
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = v
        elif k.startswith("proj"):
            image_proj_dict[k.replace("proj.", "")] = v
        elif k.startswith("adapter_modules"):
            adapter_modules_dict[k.replace("adapter_modules.", "")] = v

    ref_unet.load_state_dict(ref_unet_dict)
    image_proj.load_state_dict(image_proj_dict)
    adapter_modules.load_state_dict(adapter_modules_dict)

    noise_scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", steps_offset=1)

    pipe = IMAGDressing_v1(unet=unet, reference_unet=ref_unet, vae=vae, tokenizer=tokenizer,
                         text_encoder=text_encoder, image_encoder=image_encoder,
                         ImgProj=image_proj, scheduler=noise_scheduler,
                         safety_checker=StableDiffusionSafetyChecker,
                         feature_extractor=CLIPImageProcessor())
    return pipe, generator

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='IMAGDressing_v1')
    parser.add_argument('--model_ckpt', default="path/to/IMAGDressing-v1_512.pt", type=str)
    parser.add_argument('--cloth_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, default="./output_sd_base")
    parser.add_argument('--device', type=str, default="cuda:0")
    args = parser.parse_args()

    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    pipe, generator = prepare(args)
    print('====================== Pipe loaded successfully ===================')

    num_samples = 1
    clip_image_processor = CLIPImageProcessor()

    img_transform = transforms.Compose([
        transforms.Resize([640, 512], interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    prompt = 'A beautiful woman, best quality, high quality'
    null_prompt = ''
    negative_prompt = 'bare, naked, nude, undressed, monochrome, lowres, bad anatomy, worst quality, low quality'

    clothes_img = Image.open(args.cloth_path).convert("RGB")
    clothes_img = resize_img(clothes_img)
    vae_clothes = img_transform(clothes_img).unsqueeze(0).to(args.device)
    ref_clip_image = clip_image_processor(images=clothes_img, return_tensors="pt").pixel_values.to(args.device)

    output = pipe(ref_image=vae_clothes, prompt=prompt, ref_clip_image=ref_clip_image, null_prompt=null_prompt, negative_prompt=negative_prompt, width=512, height=640, num_images_per_prompt=num_samples, guidance_scale=7.5, image_scale=1.0, generator=generator, num_inference_steps=50).images

    save_output = [clothes_img.resize((512, 640), Image.BICUBIC)]
    save_output.append(output[0])

    grid = image_grid(save_output, 1, 2)
    grid.save(os.path.join(output_path, os.path.basename(args.cloth_path)))
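Assuming the script above is saved as run_dressing.py (the file name is arbitrary) and the "path/to/..." placeholders are replaced with your local model directories, it can be run like this:

python run_dressing.py --model_ckpt path/to/IMAGDressing-v1_512.pt --cloth_path path/to/garment.jpg --output_path ./output_sd_base

The result saved under output_path is a 1x2 grid: the resized garment image on the left and the generated person wearing it on the right.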

To be continued......

For more details, feel free to follow: 杰哥新技术
