完整代码见 refinement.py
相关内容:large mask inpainting (LaMa)图像修复
将 inpainter.model(即 FFCResNetGenerator)拆分为两部分:forward_front 部分主要进行三次图像下采样,得到潜空间变量 z1、z2;z1、z2 再经过 forward_rear 部分得到预测结果,并使用 Adam 优化 z1、z2。
def refine_predict(
    batch: dict,
    inpainter: nn.Module,
    gpu_ids: str,
    modulo: int,
    n_iters: int,
    lr: float,
    min_side: int,
    max_scales: int,
    px_budget: int,
):
    """Multi-scale refinement of a LaMa inpainting prediction.

    Splits the generator into a down-sampling front and a residual rear
    (the rear possibly sharded across several GPUs), builds an image/mask
    pyramid from coarse to fine, and lets ``_infer`` optimize the latent
    features at each scale, reusing the coarser scale's result as reference.

    Args:
        batch: input batch holding the image and mask tensors.
        inpainter: wrapper module whose ``.model`` is the trained inpainter.
        gpu_ids: comma-separated GPU indices, e.g. ``"0,1"``.
        modulo: pad image/mask so each side is a multiple of this value.
        n_iters: Adam iterations per pyramid scale.
        lr: Adam learning rate.
        min_side: smallest allowed side length in the pyramid.
        max_scales: maximum number of pyramid levels.
        px_budget: upper bound on the pixel count of a processed image.

    Returns:
        The refined inpainted image of the finest scale (cropped, on CPU).
    """
    inpainter = inpainter.model
    # Refinement assumes a frozen eval-mode model with mask concatenation.
    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    gpu_ids = [
        f"cuda:{token}"
        for token in gpu_ids.replace(" ", "").split(",")
        if token.isdigit()
    ]
    devices = [torch.device(name) for name in gpu_ids]

    # Locate the residual section of the generator: count the
    # FFCResnetBlocks and remember the index of the first one.
    n_resnet_blocks = 0
    first_resblock_ind = 0
    seen_resblock = False
    for layer in inpainter.generator.model:
        if isinstance(layer, FFCResnetBlock):
            n_resnet_blocks += 1
            seen_resblock = True
        elif not seen_resblock:
            first_resblock_ind += 1
    resblocks_per_gpu = n_resnet_blocks // len(devices)

    # Front part: everything before the first residual block (mainly the
    # down-sampling stages). It stays on the first device.
    forward_front = inpainter.generator.model[0:first_resblock_ind]
    forward_front.to(devices[0])

    # Rear part: the residual blocks plus the tail, sharded over the GPUs.
    forward_rears = []
    for shard_idx, device in enumerate(devices):
        start = first_resblock_ind + resblocks_per_gpu * shard_idx
        if shard_idx < len(devices) - 1:
            shard = inpainter.generator.model[
                start : first_resblock_ind + resblocks_per_gpu * (shard_idx + 1)
            ]
        else:
            # The last shard also absorbs the remaining (non-divisible)
            # blocks and the up-sampling tail of the generator.
            shard = inpainter.generator.model[start:]
        shard.to(device)
        forward_rears.append(shard)

    # Image/mask pyramid, resolutions ordered from low to high.
    ls_images, ls_masks = _get_image_mask_pyramid(
        batch, min_side, max_scales, px_budget
    )

    image_inpainted = None
    # Refine scale by scale; each result becomes the next scale's reference.
    for scale_ind, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        # Binarize the (possibly interpolated) mask.
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image = move_to_device(image, devices[0])
        mask = move_to_device(mask, devices[0])
        if image_inpainted is not None:
            # The coarser result is consumed on the last device in _infer.
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(
            image,
            mask,
            forward_front,
            forward_rears,
            image_inpainted,
            orig_shape,
            devices,
            scale_ind,
            n_iters,
            lr,
        )
        # Crop away the modulo padding.
        image_inpainted = image_inpainted[:, :, : orig_shape[0], : orig_shape[1]]
        # Detach everything to free resources between scales.
        image = image.detach().cpu()
        mask = mask.detach().cpu()
    return image_inpainted
L1 损失函数:默认计算预测图像与输入图像在非 mask 区域的 L1 loss。
推理阶段没有 mask 区域的真实值,因此将本次预测结果下采样后,与图像金字塔中上一层(更低分辨率)图像的预测结果进行比较来计算 loss。
所以图像金字塔中的分辨率按从低到高排列。
如果图像尺寸不满足下采样条件、无法构成图像金字塔,则不进行优化,直接返回结果。
def _l1_loss(
pred: torch.Tensor,
pred_downscaled: torch.Tensor,
ref: torch.Tensor,
mask: torch.Tensor,
mask_downscaled: torch.Tensor,
image: torch.Tensor,
on_pred: bool = True,
):
"""l1 loss on src pixels, and downscaled predictions if on_pred=True"""
#非mask部分,不需要修复的部分
loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8]))
#mask部分
if on_pred:
loss += torch.mean(
torch.abs(
pred_downscaled[mask_downscaled >= 1e-8] - ref[mask_downscaled >= 1e-8]
)
)
return loss
def _infer(
    image: torch.Tensor,
    mask: torch.Tensor,
    forward_front: nn.Module,
    forward_rears: nn.Module,
    ref_lower_res: torch.Tensor,
    orig_shape: tuple,
    devices: list,
    scale_ind: int,
    n_iters: int = 15,
    lr: float = 0.002,
):
    """Refine one pyramid scale by optimizing the latent features z1/z2.

    Runs the fixed front part once, then iterates: forward through the
    (possibly multi-GPU) rear part, compare the downscaled prediction with
    the coarser scale's result (``ref_lower_res``), and Adam-step z1/z2.
    On the coarsest scale (``ref_lower_res is None``) no optimization is
    possible and the first prediction is returned as-is.

    Returns the composited inpainted image, detached and on CPU.
    """
    # Network input: masked image with the mask appended as a 4th channel
    # (the model was trained with concat_mask=True).
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)
    # Expand the single-channel mask to 3 channels for compositing later.
    mask = mask.repeat(1, 3, 1, 1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()
    # The down-sampling front is fixed; only z1/z2 are optimized below.
    with torch.no_grad():
        z1, z2 = forward_front(masked_image)

    # Inference-side tensors live on the last device, where `pred` lands.
    mask = mask.to(devices[-1])
    ekernel = torch.from_numpy(
        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)).astype(bool)
    ).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])

    # Latents become the leaf variables driven by Adam.
    z1 = z1.detach().to(devices[0])
    z2 = z2.detach().to(devices[0])
    z1.requires_grad = True
    z2.requires_grad = True
    optimizer = Adam([z1, z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for step in pbar:
        optimizer.zero_grad()
        input_feat = (z1, z2)
        # Pipeline through the rear shards, hopping from device to device.
        for shard_idx, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if shard_idx < len(devices) - 1:
                mid1, mid2 = output_feat
                input_feat = (
                    mid1.to(devices[shard_idx + 1]),
                    mid2.to(devices[shard_idx + 1]),
                )
            else:
                pred = output_feat
        if ref_lower_res is None:
            # Coarsest scale: no reference to optimize against.
            break
        # Downscale the prediction and the (eroded) mask so they can be
        # compared with the previous scale's result.
        pred_downscaled = _pyrdown(pred[:, :, : orig_shape[0], : orig_shape[1]])
        mask_downscaled = _pyrdown_mask(
            mask[:, :1, : orig_shape[0], : orig_shape[1]],
            blur_mask=False,
            round_up=False,
        )
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
        loss = _l1_loss(
            pred,
            pred_downscaled,
            ref_lower_res,
            mask,
            mask_downscaled,
            image,
            on_pred=True,
        )
        pbar.set_description(
            "Refining scale {} using scale {} ...current loss: {:.4f}".format(
                scale_ind + 1, scale_ind, loss.item()
            )
        )
        # No update on the final iteration, so the returned `pred` matches
        # the last reported loss; free the graph between iterations.
        if step < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred

    # Composite: keep known pixels, take the prediction inside the mask.
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted