Dual Diffusion Implicit Bridges (ICLR 2023)
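Common image-to-image translation methods rely on joint training over data from both the source and target domains. This training process makes it difficult to protect the privacy of domain data, and it usually means a new model must be trained for every new pair of domains.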
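Dual Diffusion Implicit Bridges (DDIB) is an image translation method based on diffusion models that sidesteps training on pairs of domains. DDIB's translation relies on two diffusion models trained independently, one per domain, and is a two-step process: DDIB first uses the source diffusion model to obtain latent encodings of the source images, then uses the target model to decode those encodings into target images. Both steps are defined via ODEs, so the process is cycle-consistent up to the discretization error of the ODE solvers. Theoretically, DDIB is interpreted as a concatenation of source-to-latent and latent-to-target Schrödinger bridges, a form of entropy-regularized optimal transport.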
Here, DDIM's ordinary differential equation solver, ODESolve, is used to implement both DDIB steps:
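For reference, the two solves can be written out as follows, in the DDIB paper's notation as we understand it: ODESolve(x; v_θ, t0, t1) integrates dx/dt = v_θ(t, x) from t0 to t1, with t = 0 at the data end and t = 1 at the latent end, and x^(s), x^(l), x^(t) denote the source image, the latent, and the target image. The third line is the deterministic DDIM update (η = 0) between noise levels ᾱ_τ and ᾱ_τ′ that discretizes either solve; τ is used for the diffusion timestep so it does not clash with the superscript (t) for "target".

```latex
\begin{aligned}
\mathbf{x}^{(l)} &= \mathrm{ODESolve}\big(\mathbf{x}^{(s)};\ \mathbf{v}_\theta^{(s)},\ 0,\ 1\big)
  && \text{(encode with the source model)} \\
\mathbf{x}^{(t)} &= \mathrm{ODESolve}\big(\mathbf{x}^{(l)};\ \mathbf{v}_\theta^{(t)},\ 1,\ 0\big)
  && \text{(decode with the target model)} \\
\hat{\mathbf{x}}_0 &= \frac{\mathbf{x}_{\tau} - \sqrt{1-\bar\alpha_{\tau}}\,\boldsymbol\epsilon_\theta(\mathbf{x}_{\tau}, \tau)}{\sqrt{\bar\alpha_{\tau}}},
  \qquad
  \mathbf{x}_{\tau'} = \sqrt{\bar\alpha_{\tau'}}\,\hat{\mathbf{x}}_0 + \sqrt{1-\bar\alpha_{\tau'}}\,\boldsymbol\epsilon_\theta(\mathbf{x}_{\tau}, \tau)
  && \text{(one DDIM step, } \eta = 0\text{)}
\end{aligned}
```

Encoding iterates the DDIM step with τ′ > τ (as `ddim_reverse_sample_loop` does in the script below) and decoding iterates it with τ′ < τ (as `ddim_sample_loop` does with `eta=0`).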
Pseudocode:
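DDIB enforces exact cycle consistency (up to the ODE solver's discretization error) without the extra cycle loss that CycleGAN has to introduce: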
Conditional ImageNet Translation:
source (class 260) → target (class 261)
source (class 291) → target (class 292)
Key code from imagenet_translation.py:
# Names used below: dist_util, logger and the script_util helpers come from
# OpenAI's guided-diffusion package, on which DDIB builds. create_argparser,
# copy_imagenet_dataset and load_source_data_for_domain_translation are defined
# elsewhere in the DDIB repository; their imports are not part of this excerpt.
import os

import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from PIL import Image

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    args_to_dict,
    classifier_defaults,
    create_classifier,
    create_model_and_diffusion,
    model_and_diffusion_defaults,
)


def main():
    args = create_argparser().parse_args()
    logger.log(f"arguments: {args}")
    dist_util.setup_dist()
    logger.configure()

    # Diffusion model (class-conditional when args.class_cond is set).
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        dist_util.load_state_dict(args.model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    if args.use_fp16:
        model.convert_to_fp16()
    model.eval()

    # Noise-aware classifier used for classifier guidance while decoding.
    logger.log("loading classifier...")
    classifier = create_classifier(**args_to_dict(args, classifier_defaults().keys()))
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classifier_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()

    def cond_fn(x, t, y=None):
        # Gradient of log p(y | x_t) with respect to x_t, scaled by classifier_scale.
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale

    def model_fn(x, t, y=None):
        # Pass the class labels to the model only if it is class-conditional.
        assert y is not None
        return model(x, t, y if args.class_cond else None)

    # Copy the images of the source classes from the ImageNet validation set.
    logger.log("copying source dataset.")
    source = [int(v) for v in args.source.split(",")]
    target = [int(v) for v in args.target.split(",")]
    source_to_target_mapping = {s: t for s, t in zip(source, target)}
    copy_imagenet_dataset(args.val_dir, source)

    logger.log("running image translation...")
    data = load_source_data_for_domain_translation(
        batch_size=args.batch_size,
        image_size=args.image_size
    )
    for i, (batch, extra) in enumerate(data):
        logger.log(f"translating batch {i}, shape {batch.shape}.")

        # Save the cropped source images, mapping [-1, 1] back to uint8 HWC.
        logger.log("saving the original, cropped images.")
        images = ((batch + 1) * 127.5).clamp(0, 255).to(th.uint8)
        images = images.permute(0, 2, 3, 1)
        images = images.contiguous()
        images = images.cpu().numpy()
        for index in range(images.shape[0]):
            filepath = extra["filepath"][index]
            image = Image.fromarray(images[index])
            image.save(filepath)
            logger.log(f" saving: {filepath}")
        batch = batch.to(dist_util.dev())

        # Class labels for the source and target sets.
        source_y = dict(y=extra["y"].to(dist_util.dev()))
        target_y_list = [source_to_target_mapping[v.item()] for v in extra["y"]]
        target_y = dict(y=th.tensor(target_y_list).to(dist_util.dev()))

        # Step 1: DDIM-encode the source images to latents, conditioned on the source class.
        logger.log("encoding the source images.")
        noise = diffusion.ddim_reverse_sample_loop(
            model_fn,
            batch,
            clip_denoised=False,
            model_kwargs=source_y,
            device=dist_util.dev(),
        )
        logger.log(f"obtained latent representation for {batch.shape[0]} samples...")
        logger.log(f"latent with mean {noise.mean()} and std {noise.std()}")

        # Step 2: decode the latents, conditioned on the target class plus classifier guidance.
        sample = diffusion.ddim_sample_loop(
            model_fn,
            (args.batch_size, 3, args.image_size, args.image_size),
            noise=noise,
            clip_denoised=args.clip_denoised,
            model_kwargs=target_y,
            cond_fn=cond_fn,
            device=dist_util.dev(),
            eta=args.eta
        )
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()

        # Save only this process's samples: extra["filepath"] lists the local batch
        # alone, so indexing it with all-gathered samples breaks with >1 process.
        images = sample.cpu().numpy()
        logger.log(f"created {images.shape[0]} samples")
        logger.log("saving translated images.")
        for index in range(images.shape[0]):
            base_dir, filename = os.path.split(extra["filepath"][index])
            filename, ext = os.path.splitext(filename)  # ext keeps its leading dot
            filepath = os.path.join(
                base_dir, f"{filename}_translated_{target_y_list[index]}{ext}"
            )
            image = Image.fromarray(images[index])
            image.save(filepath)
            logger.log(f" saving: {filepath}")
        dist.barrier()

    logger.log("domain translation complete")


if __name__ == "__main__":
    main()
Key log output:
creating model and diffusion...
loading classifier...
copying source dataset.
Copying image files for class 291.
running image translation...
translating batch 1, shape torch.Size([8, 3, 256, 256]).
saving the original, cropped images.
saving: ./experiments/imagenet/291_17.JPG
saving: ./experiments/imagenet/291_18.JPG
saving: ./experiments/imagenet/291_19.JPG
saving: ./experiments/imagenet/291_2.JPG
saving: ./experiments/imagenet/291_20.JPG
saving: ./experiments/imagenet/291_21.JPG
saving: ./experiments/imagenet/291_22.JPG
saving: ./experiments/imagenet/291_23.JPG
encoding the source images.
obtained latent representation for 8 samples...
latent with mean -0.0013304136227816343 and std 0.9858766198158264
created 8 samples
saving translated images.
saving: ./experiments/imagenet/291_17_translated_292.JPG
saving: ./experiments/imagenet/291_18_translated_292.JPG
saving: ./experiments/imagenet/291_19_translated_292.JPG
saving: ./experiments/imagenet/291_2_translated_292.JPG
saving: ./experiments/imagenet/291_20_translated_292.JPG
saving: ./experiments/imagenet/291_21_translated_292.JPG
saving: ./experiments/imagenet/291_22_translated_292.JPG
saving: ./experiments/imagenet/291_23_translated_292.JPG
translating batch 2, shape torch.Size([8, 3, 256, 256]).
saving the original, cropped images.
saving: ./experiments/imagenet/291_24.JPG
saving: ./experiments/imagenet/291_25.JPG
…
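Privacy-Sensitive Image Translation
During translation, only the latent codes and the translated images are sent over a public channel; the source and target datasets stay private to their respective owners. This is a notable advantage of DDIB over other methods, since it supports strong privacy protection for the datasets.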