Image Inpainting Refinement

The full code is in refinement.py.
Related post: large mask inpainting (LaMa) image inpainting.
inpainter.model (an FFCResNetGenerator) is split into two parts: the forward_front part mainly performs three downsampling stages and produces the latent variables z1 and z2; the forward_rear part maps them to the predicted result. Adam is then used to optimize z1 and z2.
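The core trick, in miniature (a hypothetical toy model, not LaMa's actual generator): freeze the network weights, run the front half once to obtain intermediate features, then treat those features as the optimization variables.

import torch
import torch.nn as nn
from torch.optim import Adam

# Hypothetical two-part stand-in for forward_front / forward_rear.
front = nn.Conv2d(3, 8, 3, stride=2, padding=1)
rear = nn.Conv2d(8, 3, 3, padding=1)

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    z = front(x)               # intermediate features, analogous to z1/z2
z.requires_grad_(True)         # optimize the features, not the weights

optimizer = Adam([z], lr=2e-3)
for _ in range(15):
    optimizer.zero_grad()
    pred = rear(z)
    loss = pred.abs().mean()   # placeholder loss; the real one is a multi-scale L1
    loss.backward()
    optimizer.step()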

import cv2
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
# FFCResnetBlock and the pyramid / padding / device helpers are imported
# from elsewhere in the LaMa repository; see refinement.py for the full list.

def refine_predict(
    batch: dict,
    inpainter: nn.Module,
    gpu_ids: str,
    modulo: int,
    n_iters: int,
    lr: float,
    min_side: int,
    max_scales: int,
    px_budget: int,
):
    inpainter = inpainter.model
    assert not inpainter.training
    assert not inpainter.add_noise_kwargs
    assert inpainter.concat_mask

    gpu_ids = [
        f"cuda:{gpuid}"
        for gpuid in gpu_ids.replace(" ", "").split(",")
        if gpuid.isdigit()
    ]
    n_resnet_blocks = 0
    first_resblock_ind = 0
    found_first_resblock = False
    for idl in range(len(inpainter.generator.model)):
        if isinstance(inpainter.generator.model[idl], FFCResnetBlock):
            n_resnet_blocks += 1
            found_first_resblock = True
        elif not found_first_resblock:
            first_resblock_ind += 1
    resblocks_per_gpu = n_resnet_blocks // len(gpu_ids)

    devices = [torch.device(gpu_id) for gpu_id in gpu_ids]

    # Split the model into front and rear parts; the front part mainly handles downsampling
    forward_front = inpainter.generator.model[0:first_resblock_ind]
    forward_front.to(devices[0])
    forward_rears = []
    for idd in range(len(gpu_ids)):
        if idd < len(gpu_ids) - 1:
            forward_rears.append(
                inpainter.generator.model[
                    first_resblock_ind + resblocks_per_gpu * (idd) : first_resblock_ind
                    + resblocks_per_gpu * (idd + 1)
                ]
            )
        else:
            forward_rears.append(
                inpainter.generator.model[
                    first_resblock_ind + resblocks_per_gpu * (idd) :
                ]
            )
        forward_rears[idd].to(devices[idd])
    # Build the image and mask pyramids
    ls_images, ls_masks = _get_image_mask_pyramid(
        batch, min_side, max_scales, px_budget
    )
    image_inpainted = None
    # Refine scale by scale: optimize the latents and produce the prediction
    for ids, (image, mask) in enumerate(zip(ls_images, ls_masks)):
        orig_shape = image.shape[2:]
        image = pad_tensor_to_modulo(image, modulo)
        mask = pad_tensor_to_modulo(mask, modulo)
        mask[mask >= 1e-8] = 1.0
        mask[mask < 1e-8] = 0.0
        image, mask = (
            move_to_device(image, devices[0]),
            move_to_device(mask, devices[0]),
        )
        if image_inpainted is not None:
            image_inpainted = move_to_device(image_inpainted, devices[-1])
        image_inpainted = _infer(
            image, mask, forward_front, forward_rears,
            image_inpainted, orig_shape, devices, ids, n_iters, lr,
        )
        image_inpainted = image_inpainted[:, :, : orig_shape[0], : orig_shape[1]]
        # detach everything to save resources
        image = image.detach().cpu()
        mask = mask.detach().cpu()

    return image_inpainted
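refine_predict relies on _get_image_mask_pyramid, defined elsewhere in refinement.py, to build the coarse-to-fine pyramid. A minimal sketch of the idea, assuming the batch carries "image" and "mask" tensors (the key names are assumptions of this sketch, and the real helper also enforces the px_budget pixel cap and handles unpadding):

import torch.nn.functional as F

def _get_image_mask_pyramid(batch, min_side, max_scales, px_budget):
    # Sketch only: repeatedly halve until hitting min_side or max_scales,
    # then return the levels ordered from lowest to highest resolution.
    images, masks = [batch["image"]], [batch["mask"]]
    while (len(images) < max_scales
           and min(images[-1].shape[2:]) // 2 >= min_side):
        images.append(F.interpolate(images[-1], scale_factor=0.5,
                                    mode="bilinear", align_corners=False))
        masks.append(F.interpolate(masks[-1], scale_factor=0.5,
                                   mode="bilinear", align_corners=False))
    return images[::-1], masks[::-1]  # coarse to fine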

The L1 loss is, by default, the L1 distance between the prediction and the input image over the non-masked region.
At inference time there is no ground truth for the masked region, so the current prediction is downscaled and compared against the prediction from the previous (lower-resolution) level of the image pyramid to compute the loss.
The image pyramid is therefore processed from low to high resolution.
If the image is too small to be downsampled, there is no pyramid, no refinement is possible, and the result is returned directly.


def _l1_loss(
    pred: torch.Tensor,
    pred_downscaled: torch.Tensor,
    ref: torch.Tensor,
    mask: torch.Tensor,
    mask_downscaled: torch.Tensor,
    image: torch.Tensor,
    on_pred: bool = True,
):
    """l1 loss on src pixels, and downscaled predictions if on_pred=True"""
    # Non-masked region: the known pixels that need no inpainting
    loss = torch.mean(torch.abs(pred[mask < 1e-8] - image[mask < 1e-8]))
    # Masked region: compare the downscaled prediction against the lower-resolution reference
    if on_pred:
        loss += torch.mean(
            torch.abs(
                pred_downscaled[mask_downscaled >= 1e-8] - ref[mask_downscaled >= 1e-8]
            )
        )
    return loss
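A quick check on toy tensors (all shapes here are arbitrary) shows how the boolean indexing splits the loss into a known-pixels term and a cross-scale term:

import torch

pred = torch.rand(1, 3, 8, 8)
image = torch.rand(1, 3, 8, 8)
mask = torch.zeros(1, 3, 8, 8)
mask[..., 2:6, 2:6] = 1.0                  # square hole in the middle
pred_downscaled = torch.rand(1, 3, 4, 4)   # downscaled current prediction
ref = torch.rand(1, 3, 4, 4)               # previous (coarser) scale's prediction
mask_downscaled = torch.zeros(1, 3, 4, 4)
mask_downscaled[..., 1:3, 1:3] = 1.0
loss = _l1_loss(pred, pred_downscaled, ref, mask, mask_downscaled, image, on_pred=True)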

def _infer(
    image: torch.Tensor,
    mask: torch.Tensor,
    forward_front: nn.Module,
    forward_rears: list,
    ref_lower_res: torch.Tensor,
    orig_shape: tuple,
    devices: list,
    scale_ind: int,
    n_iters: int = 15,
    lr: float = 0.002,
):
    masked_image = image * (1 - mask)
    masked_image = torch.cat([masked_image, mask], dim=1)

    mask = mask.repeat(1, 3, 1, 1)
    if ref_lower_res is not None:
        ref_lower_res = ref_lower_res.detach()
    with torch.no_grad():
        z1, z2 = forward_front(masked_image)
    # Inference
    mask = mask.to(devices[-1])
    ekernel = torch.from_numpy(
        cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 15)).astype(bool)
    ).float()
    ekernel = ekernel.to(devices[-1])
    image = image.to(devices[-1])
    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
    z1.requires_grad, z2.requires_grad = True, True

    optimizer = Adam([z1, z2], lr=lr)

    pbar = tqdm(range(n_iters), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
        input_feat = (z1, z2)
        for idd, forward_rear in enumerate(forward_rears):
            output_feat = forward_rear(input_feat)
            if idd < len(devices) - 1:
                midz1, midz2 = output_feat
                midz1, midz2 = midz1.to(devices[idd + 1]), midz2.to(devices[idd + 1])
                input_feat = (midz1, midz2)
            else:
                pred = output_feat

        if ref_lower_res is None:
            break
        losses = {}
        # scaled loss with downsampler
        # Downscale the prediction
        pred_downscaled = _pyrdown(pred[:, :, : orig_shape[0], : orig_shape[1]])
        # Downscale the mask, then erode it
        mask_downscaled = _pyrdown_mask(
            mask[:, :1, : orig_shape[0], : orig_shape[1]],
            blur_mask=False,
            round_up=False,
        )
        mask_downscaled = _erode_mask(mask_downscaled, ekernel=ekernel)
        mask_downscaled = mask_downscaled.repeat(1, 3, 1, 1)
        losses["ms_l1"] = _l1_loss(
            pred,
            pred_downscaled,
            ref_lower_res,
            mask,
            mask_downscaled,
            image,
            on_pred=True,
        )

        loss = sum(losses.values())
        pbar.set_description(
            "Refining scale {} using scale {} ...current loss: {:.4f}".format(
                scale_ind + 1, scale_ind, loss.item()
            )
        )
        if idi < n_iters - 1:
            loss.backward()
            optimizer.step()
            del pred_downscaled
            del loss
            del pred
    # "pred" is the prediction after Plug-n-Play module
    inpainted = mask * pred + (1 - mask) * image
    inpainted = inpainted.detach().cpu()
    return inpainted
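_pyrdown, _pyrdown_mask, and _erode_mask are defined elsewhere in refinement.py. A minimal sketch of plausible implementations (not the repository's exact code: the real _pyrdown applies a proper Gaussian blur before downsampling, and the real erosion uses the elliptical structuring element built in _infer):

import torch.nn.functional as F

def _pyrdown(im):
    # Sketch: blur-and-halve. Bilinear interpolation is a crude stand-in
    # for the anti-aliasing blur used in the real helper.
    return F.interpolate(im, scale_factor=0.5, mode="bilinear",
                         align_corners=False)

def _pyrdown_mask(mask, blur_mask=True, round_up=True):
    if blur_mask:
        mask = F.avg_pool2d(mask, 3, stride=1, padding=1)  # soften edges
    mask = F.interpolate(mask, scale_factor=0.5, mode="bilinear",
                         align_corners=False)
    if round_up:
        mask[mask >= 1e-8] = 1.0   # any partly covered pixel stays a hole
    else:
        mask[mask < 1.0] = 0.0     # only fully covered pixels stay a hole
    return mask

def _erode_mask(mask, ekernel=None):
    # Sketch: erosion shrinks the hole so the cross-scale L1 is computed
    # away from the hole boundary. A square min-filter approximates the
    # elliptical kernel.
    if ekernel is None:
        return mask
    k = ekernel.shape[-1]
    dilated_known = F.max_pool2d(1.0 - mask, k, stride=1, padding=k // 2)
    return 1.0 - dilated_known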