Asymmetric Mask Scheme for Self-Supervised Real Image Denoising

1. Characteristics of BSN

First, feature extraction produces a feature map.
A blind-spot convolution then extracts features from every pixel except the central one.

After that, dilated convolutions are required so that the network cannot learn the identity mapping.

See the network in the AP-BSN paper for a concrete architecture.
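To make the blind-spot constraint concrete, here is a minimal sketch (my own illustration, not code from the paper) of a centrally masked convolution, the basic building block of a BSN:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CentralMaskedConv2d(nn.Conv2d):
    """Convolution whose centre weight is zeroed, so the output at each
    position never depends on that position's own (noisy) pixel."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mask = torch.ones_like(self.weight)
        # zero out the centre tap of every kernel
        mask[:, :, self.kernel_size[0] // 2, self.kernel_size[1] // 2] = 0
        self.register_buffer('weight_mask', mask)

    def forward(self, x):
        return F.conv2d(x, self.weight * self.weight_mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)
```

Changing the centre pixel of the input leaves the output at that location unchanged, which is exactly the blind-spot property.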

2. Inspiration from MAE

Because a BSN relies on such restricted filters, the network design is heavily constrained.

MAE shows that an image can still be reconstructed even when large parts of it are masked, which inspired the authors to design a mask-based network.

Masking the image directly and then restoring it avoids the BSN constraints, so the BSN can be replaced with an ordinary network.

3. Handling spatial correlation

Spatial correlation is again broken with pixel downsampling; the only change is the additional mask on top.

4. Asymmetric mask scheme

Training uses only a single branch network.
Inference uses the same single branch, but feeds it two images with complementary masks.

At inference the masks satisfy the property that, across the branches, the masked pixels together make up the entire image.

5. Checkerboard artifacts caused by the shuffle

A new loss function is introduced for fine-tuning.

6. Code

6.1 Image decomposition and recomposition

The decomposition uses the pixel_shuffle and pixel_unshuffle functions and corresponds to the upper part of Figure 5.

import torch
import torch.nn.functional as F


def pd_down(x: torch.Tensor, pd_factor: int = 5, pad: int = 0) -> torch.Tensor:
    b, c, h, w = x.shape
    # (b, c, h, w) -> (b, c * pd^2, h / pd, w / pd)
    x_down = F.pixel_unshuffle(x, pd_factor)
    # move the pd^2 sub-images into the batch dimension
    out = x_down.view(b, c, pd_factor, pd_factor, h // pd_factor, w // pd_factor).permute(
        0, 2, 3, 1, 4, 5).reshape(b * pd_factor * pd_factor, c, h // pd_factor, w // pd_factor)
    return out


def pd_up(out: torch.Tensor, pd_factor: int = 5, pad: int = 0) -> torch.Tensor:
    b, c, h, w = out.shape
    # Reshape the output tensor to its original shape after pixel unshuffle
    x_down = out.view(b // (pd_factor ** 2), pd_factor, pd_factor, c, h,
                      w).permute(0, 3, 1, 2, 4, 5)
    x_down = x_down.reshape(b // (pd_factor ** 2), c *
                            pd_factor * pd_factor, h, w)
    # Use pixel shuffle to upsample the tensor
    x_up = F.pixel_shuffle(x_down, pd_factor)
    return x_up

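A quick sanity check of the two helpers (redefined inline so the snippet runs standalone): pd_up should exactly invert pd_down.

```python
import torch
import torch.nn.functional as F


def pd_down(x: torch.Tensor, pd: int = 5) -> torch.Tensor:
    # (b, c, h, w) -> (b * pd^2, c, h / pd, w / pd): one sub-image per phase
    b, c, h, w = x.shape
    x = F.pixel_unshuffle(x, pd)
    x = x.view(b, c, pd, pd, h // pd, w // pd).permute(0, 2, 3, 1, 4, 5)
    return x.reshape(b * pd * pd, c, h // pd, w // pd)


def pd_up(x: torch.Tensor, pd: int = 5) -> torch.Tensor:
    # exact inverse of pd_down
    b, c, h, w = x.shape
    x = x.view(b // (pd * pd), pd, pd, c, h, w).permute(0, 3, 1, 2, 4, 5)
    x = x.reshape(b // (pd * pd), c * pd * pd, h, w)
    return F.pixel_shuffle(x, pd)


x = torch.arange(3 * 10 * 10, dtype=torch.float32).reshape(1, 3, 10, 10)
sub = pd_down(x, 5)          # 25 sub-images of shape (3, 2, 2)
restored = pd_up(sub, 5)     # back to (1, 3, 10, 10)
```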
6.2 Training and inference flow

class MultiMaskPdDn(nn.Module):
    def __init__(self, pd_train: int = 5, pd_val: int = 2, dn_net: str = 'default', r3: float = -1, r3_num: int = 8,
                 net_param: dict[str, float | str | int] = None, **kwargs):
        super().__init__()

        self.dn = dn_dict[dn_net](**net_param if net_param is not None else {})
        self.pd_train = pd_train
        self.pd_val = pd_val
        self.r3 = R3(r3, r3_num)

    def denoise(self, x: torch.Tensor, pd_factor: int = None, return_mask: bool = False, only_first: bool = True) -> torch.Tensor:
        # downsample, denoise, upsample
        if pd_factor is None:
            pd_factor = self.pd_train
        if pd_factor > 1:
            x = util.pd_down(x, pd_factor)
        if return_mask:
            dn_img, masks = self.dn(x, True, only_first)
        else:
            dn_img = self.dn(x, False, only_first)
        if pd_factor > 1:
            dn_img = util.pd_up(dn_img, pd_factor)
        return dn_img if not return_mask else (dn_img, masks)

    def forward(self, x: torch.Tensor, pd_factor: int = None, return_mask: bool = False, only_first: bool = True) -> torch.Tensor:
        if self.training:
            return self.denoise(x, pd_factor, return_mask, only_first)
        else:
            # compared with training, inference adds an R3 step
            denoised = self.denoise(x, self.pd_val, only_first=False)
            return self.r3(x, denoised, self.dn)

6.3 Masks

This corresponds to Eq. (5) in the paper.

If there are 3 masks, mask_index is filled with random values 0, 1 and 2.
In masks, False (0) marks the masked region.
In res, the masked region is set to 0.

So with n masks, each masked region covers 1/n of the pixels.

class MultiScaleMask(nn.Module):
    def __init__(self, scale_num: int = 2):
        super().__init__()
        self.scale_num = scale_num

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if (len(x.shape) == 3):
            x = x.unsqueeze(0)
        b, c, h, w = x.shape
        # fill the image with random integers in [0, scale_num - 1]
        mask_index = torch.randint(
            0, self.scale_num, (b, 1, h, w)).expand(-1, c, -1, -1).to(x.device)
        
        res = torch.zeros(self.scale_num, *x.shape).to(x.device)
        masks = torch.empty(self.scale_num, *x.shape, dtype=torch.bool, device=x.device)


        for i in range(self.scale_num):
            temp = mask_index != i
            masks[i] = temp
            res[i] = x*temp
        return res, masks
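A standalone sketch of what the mask does (assuming scale_num = 2, as in the default configuration): the masked regions of the two masks are complementary, so together they cover every pixel.

```python
import torch

torch.manual_seed(0)
scale_num = 2
x = torch.rand(1, 3, 4, 4)
b, c, h, w = x.shape

# one random branch index per spatial position, shared across channels
mask_index = torch.randint(0, scale_num, (b, 1, h, w)).expand(-1, c, -1, -1)

# masks[i] is False ( = masked) where mask_index == i
masks = torch.stack([mask_index != i for i in range(scale_num)])
# res[i] is the input with the masked region zeroed out
res = torch.stack([x * m for m in masks])
```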

There are 2 masks by default.

Each mask has shape BCHW.

The masked images are fed into the branch network; in fact there is only one branch.
Both masked images pass through it to produce out,
and from out only the masked elements are kept.

def forward(self, x: torch.Tensor, return_mask: bool = False, only_first: bool = False) -> torch.Tensor:
        masked_img, masks = self.mask(x)
        dn_img = torch.zeros_like(x).to(dtype=torch.float32)
        order_len = min(self.mask_num, len(self.branches_order))
        for i, j in zip(self.branches_order, range(order_len)):
            out = self.branches[i](masked_img[j])
            dn_img[~masks[j]] = out[~masks[j]]
            # during training only one mask is used; it is a random one anyway
            if only_first and return_mask:
                break
        return dn_img if not return_mask else (dn_img, masks)

6.4 Loss functions

During training only_first=True, so only one mask takes effect: roughly 50% of the pixels are randomly masked and the loss is built on them.

During inference, however, several masks are used, and their masked pixels together cover the whole image.
By default the 2 masks are complementary, so the masked regions of the denoised outputs combine into the complete denoised image (see 6.3).

In short, the unmasked pixels are used to predict the masked ones.

class MaskLoss(nn.Module):
    def __init__(self, loss_type: str = 'l1') -> None:
        super().__init__()
        self.loss = losses_dict[loss_type]()

    def forward(self, input: torch.Tensor, output: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
        total_loss = 0
        for mask in masks:
            total_loss += self.loss(input[~mask], output[~mask])
        return total_loss

# build the loss only on the first branch
class FirstBranchMaskLoss(nn.Module):
    def __init__(self, loss_type: str = 'l1') -> None:
        super().__init__()
        self.loss = losses_dict[loss_type]()

    def forward(self, input: torch.Tensor, output: torch.Tensor, masks: list[torch.Tensor]) -> torch.Tensor:
        total_loss = 0
        for mask in masks:
            total_loss += self.loss(input[~mask], output[~mask])
            break
        return total_loss
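A self-contained sketch of the masked loss with loss_type = 'l1' (variable names here are illustrative): the L1 loss is computed only at the masked positions, so the network is trained to predict pixels it never saw.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
noisy = torch.rand(1, 3, 4, 4)       # network input (noisy image)
output = torch.rand(1, 3, 4, 4)      # network output
mask = torch.rand(1, 3, 4, 4) > 0.5  # True = visible, False = masked

# loss only over the masked (False) positions
loss = F.l1_loss(noisy[~mask], output[~mask])
```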

6.5 R3 refinement

This code requires a batch size of 1.

class R3(nn.Module):
    def __init__(self, r3: float = -1, r3_num: int = 8):
        super().__init__()
        self.r3 = r3
        self.r3_num = r3_num
        if r3 <= 0:
            self.enhance = AsymMaskEnhance()

    def forward(self, x: torch.Tensor, denoised: torch.Tensor, net: nn.Module) -> torch.Tensor:
        if self.r3 > 0:
            return util.r3(x, denoised, net, self.r3, self.r3_num)
        else:
            return denoised
            # return self.enhance(x, denoised, net)
def r3(x: torch.Tensor, denoised: torch.Tensor, net: nn.Module, r3_factor: float = 0.16, r3_num: int = 8,
       p: int = 0) -> torch.Tensor:
    """random replacement Refinement with ratio r3_factor

    Note: 
        This module is only used in eval, not in train. val will take r3_num times longer than train.
    Args:
        x(torch.Tensor): input tensor.BCHW,B=1
        net(nn.Module): model to eval
        r3_factor (float, optional): the ratio of radnom replace. Defaults to 0.16.
        r3_num (int, optional): the number of r3 times. Defaults to 8.
    Output: BCHW
    """
    b, c, h, w = x.shape
    temp_input = denoised.expand(r3_num, -1, -1, -1)
    x = x.expand(r3_num, -1, -1, -1).to(dtype=torch.float32)
    indices = torch.zeros(r3_num, c, h, w, dtype=torch.bool, device=x.device)
    for t in range(r3_num):
        indices[t] = (torch.rand(1, h, w) < r3_factor).repeat(c, 1, 1)
    # a fraction r3_factor (16% by default) of the denoised pixels is replaced with x
    
    temp_input = temp_input.clone()
    temp_input[indices] = x[indices]
    temp_input = F.pad(temp_input, (p, p, p, p), mode='reflect')

    # feed every replaced copy to the net, then average the outputs
    with torch.no_grad():
        if p == 0:
            denoised = net(temp_input)
        else:
            denoised = net(temp_input)[:, :, p:-p, p:-p]
    return torch.mean(denoised, dim=0).unsqueeze(0)

6.6 Smoothness (TV) regularization

class TVLoss(torch.nn.L1Loss):
    """Weighted TV loss.

    Args:
        reduction (str): Loss method. Default: mean.
    """

    def __init__(self, reduction='mean'):
        if reduction not in ['mean', 'sum']:
            raise ValueError(
                f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
        super(TVLoss, self).__init__(reduction=reduction)

    def forward(self, pred):
        y_diff = super().forward(
            pred[:, :, :-1, :], pred[:, :, 1:, :])
        x_diff = super().forward(
            pred[:, :, :, :-1], pred[:, :, :, 1:])

        loss = x_diff + y_diff

        return loss
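A quick check of the TV term, computed inline (the same quantity TVLoss produces with reduction='mean'): a constant image has zero total variation, while a horizontal ramp contributes only through the x-differences.

```python
import torch
import torch.nn.functional as F


def tv_loss(pred: torch.Tensor) -> torch.Tensor:
    # mean absolute difference between vertically / horizontally adjacent pixels
    y_diff = F.l1_loss(pred[:, :, :-1, :], pred[:, :, 1:, :])
    x_diff = F.l1_loss(pred[:, :, :, :-1], pred[:, :, :, 1:])
    return x_diff + y_diff


flat = torch.full((1, 1, 4, 4), 0.5)                       # constant image
ramp = torch.arange(4.0).view(1, 1, 1, 4).expand(1, 1, 4, 4)  # 0,1,2,3 per row
```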