Reproducing ImageNet Translation from the DDIB Paper

Dual Diffusion Implicit Bridges (ICLR 2023)

paper | code

  Common image-to-image translation methods rely on joint training on data from both the source and target domains. This makes it hard to protect the privacy of domain data, and it usually means that a new model has to be trained for each new pair of domains.

  Dual Diffusion Implicit Bridges (DDIB) is a diffusion-based image translation method that sidesteps such domain-pair training. Translation with DDIB relies on two diffusion models trained independently on each domain and is a two-step process: DDIB first obtains a latent encoding of the source image with the source diffusion model, then decodes that encoding with the target model to construct the target image. Both steps are defined by ODEs, so the process is cycle-consistent up to the discretization error of the ODE solvers. Theoretically, the paper interprets DDIB as a concatenation of a source-to-latent and a latent-to-target Schrödinger bridge, a form of entropy-regularized optimal transport.

The two DDIB steps are implemented with ODESolve, the ordinary-differential-equation solver from DDIM:
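In the paper's notation, with $v_\theta^{(s)}$ and $v_\theta^{(t)}$ denoting the probability-flow ODE velocity fields of the source and target models, the two steps are

$$x^{(l)} = \mathrm{ODESolve}\big(x^{(s)}; v_\theta^{(s)}, 0, 1\big), \qquad x^{(t)} = \mathrm{ODESolve}\big(x^{(l)}; v_\theta^{(t)}, 1, 0\big),$$

where $\mathrm{ODESolve}(x; v, t_0, t_1)$ integrates the ODE with drift $v$ from time $t_0$ to $t_1$ starting at $x$.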
Pseudocode:
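The paper's Algorithm 1 simply chains the two ODESolve calls above. A minimal sketch in Python, reusing the guided-diffusion loop names that appear in the script below (argument lists abbreviated relative to the real signatures):

def ddib_translate(source_model, target_model, diffusion, x_source):
    # Step 1 (encode): integrate the DDIM ODE forward in time under the
    # source model, mapping the image into the shared Gaussian latent space.
    x_latent = diffusion.ddim_reverse_sample_loop(
        source_model, x_source, clip_denoised=False
    )
    # Step 2 (decode): integrate the DDIM ODE backward in time under the
    # target model; eta=0 keeps the reverse pass deterministic.
    x_target = diffusion.ddim_sample_loop(
        target_model, x_source.shape, noise=x_latent, eta=0.0
    )
    return x_target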
DDIB enforces exact cycle consistency, without the extra cycle loss that CycleGAN needs:
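Encoding and then decoding under the same model integrates the same ODE forward and then backward in time, so the round trip is the identity map:

$$\mathrm{ODESolve}\big(\mathrm{ODESolve}(x^{(s)}; v_\theta^{(s)}, 0, 1);\ v_\theta^{(s)}, 1, 0\big) = x^{(s)},$$

exactly in the continuous-time limit, and up to the solver's discretization error in practice.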


Conditional ImageNet Translation:

source (class 260) → target (class 261):
source (class 291) → target (class 292):
Key code from imagenet_translation.py:

# Imports reconstructed from the guided-diffusion / DDIB codebase.
import os

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from PIL import Image

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    args_to_dict,
)

# create_argparser, copy_imagenet_dataset and
# load_source_data_for_domain_translation are defined elsewhere in the
# DDIB repository and are omitted here.


def main():
    args = create_argparser().parse_args()
    logger.log(f"arguments: {args}")

    dist_util.setup_dist()
    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    # Classifier guidance: the gradient of log p(y | x_t) with respect to x_t,
    # scaled by classifier_scale, steers sampling toward class y.
    def cond_fn(x, t, y=None):
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

    # Pass the class label through only when the model is class-conditional.
    def model_fn(x, t, y=None):
        assert y is not None
        return model(x, t, y if args.class_cond else None)

    # Copies the source dataset from ImageNet validation set.
    logger.log("copying source dataset.")
    source = [int(v) for v in args.source.split(",")]
    target = [int(v) for v in args.target.split(",")]
    source_to_target_mapping = {s: t for s, t in zip(source, target)}
    copy_imagenet_dataset(args.val_dir, source)

    logger.log("running image translation...")
    data = load_source_data_for_domain_translation(
        batch_size=args.batch_size,
        image_size=args.image_size
    )

    for i, (batch, extra) in enumerate(data):
        logger.log(f"translating batch {i}, shape {batch.shape}.")

        logger.log("saving the original, cropped images.")
        images = ((batch + 1) * 127.5).clamp(0, 255).to(th.uint8)
        images = images.permute(0, 2, 3, 1)
        images = images.contiguous()
        images = images.cpu().numpy()
        for index in range(images.shape[0]):
            filepath = extra["filepath"][index]
            image = Image.fromarray(images[index])
            image.save(filepath)
            logger.log(f"    saving: {filepath}")

        batch = batch.to(dist_util.dev())

        # Class labels for source and target sets
        source_y = dict(y=extra["y"].to(dist_util.dev()))
        target_y_list = [source_to_target_mapping[v.item()] for v in extra["y"]]
        target_y = dict(y=th.tensor(target_y_list).to(dist_util.dev()))

        # First, use DDIM to encode to latents.
        logger.log("encoding the source images.")
        noise = diffusion.ddim_reverse_sample_loop(
            model_fn,
            batch,
            clip_denoised=False,
            model_kwargs=source_y,
            device=dist_util.dev(),
        )
        logger.log(f"obtained latent representation for {batch.shape[0]} samples...")
        logger.log(f"latent with mean {noise.mean()} and std {noise.std()}")

        # Next, decode the latents to the target class.
        sample = diffusion.ddim_sample_loop(
            model_fn,
            (args.batch_size, 3, args.image_size, args.image_size),
            noise=noise,
            clip_denoised=args.clip_denoised,
            model_kwargs=target_y,
            cond_fn=cond_fn,
            device=dist_util.dev(),
            eta=args.eta,  # eta=0 yields the deterministic DDIM ODE that exact cycle consistency requires
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        images = []
        gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
        images.extend([sample.cpu().numpy() for sample in gathered_samples])
        logger.log(f"created {len(images) * args.batch_size} samples")

        logger.log("saving translated images.")
        images = np.concatenate(images, axis=0)

        for index in range(images.shape[0]):
            base_dir, filename = os.path.split(extra["filepath"][index])
            filename, ext = filename.rsplit(".", 1)  # split off only the final extension
            filepath = os.path.join(base_dir, f"{filename}_translated_{target_y_list[index]}.{ext}")
            image = Image.fromarray(images[index])
            image.save(filepath)
            logger.log(f"    saving: {filepath}")

    dist.barrier()
    logger.log(f"domain translation complete")

Key log output:

creating model and diffusion…
loading classifier…
copying source dataset.
Copying image files for class 291.
running image translation…
translating batch 1, shape torch.Size([8, 3, 256, 256]).
saving the original, cropped images.
saving: ./experiments/imagenet/291_17.JPG
saving: ./experiments/imagenet/291_18.JPG
saving: ./experiments/imagenet/291_19.JPG
saving: ./experiments/imagenet/291_2.JPG
saving: ./experiments/imagenet/291_20.JPG
saving: ./experiments/imagenet/291_21.JPG
saving: ./experiments/imagenet/291_22.JPG
saving: ./experiments/imagenet/291_23.JPG
encoding the source images.
obtained latent representation for 8 samples…
latent with mean -0.0013304136227816343 and std 0.9858766198158264
created 8 samples
saving translated images.
saving: ./experiments/imagenet/291_17_translated_292.JPG
saving: ./experiments/imagenet/291_18_translated_292.JPG
saving: ./experiments/imagenet/291_19_translated_292.JPG
saving: ./experiments/imagenet/291_2_translated_292.JPG
saving: ./experiments/imagenet/291_20_translated_292.JPG
saving: ./experiments/imagenet/291_21_translated_292.JPG
saving: ./experiments/imagenet/291_22_translated_292.JPG
saving: ./experiments/imagenet/291_23_translated_292.JPG
translating batch 2, shape torch.Size([8, 3, 256, 256]).
saving the original, cropped images.
saving: ./experiments/imagenet/291_24.JPG
saving: ./experiments/imagenet/291_25.JPG


Privacy-Sensitive Image Translation


During translation, only the latent codes and the translated images are transmitted over the public channel; the source and target datasets remain private to their respective owners. This is a notable advantage of DDIB over other methods, since it supports strong privacy protection for the datasets.
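
The two-party exchange this enables can be sketched as follows (send and receive are hypothetical transport helpers; only the latent code and the translated image ever cross the public channel):

# Party A owns the source dataset and the source diffusion model: encode only.
x_latent = diffusion_a.ddim_reverse_sample_loop(source_model, x_source)
send(x_latent)  # hypothetical helper: transmit over the public channel

# Party B owns the target dataset and the target diffusion model: decode only.
x_latent = receive()  # hypothetical helper: receive from the public channel
x_target = diffusion_b.ddim_sample_loop(
    target_model, x_latent.shape, noise=x_latent, eta=0.0
)
send(x_target)  # return the translated image to Party A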
