论文笔记:Partial Convolutions局部卷积——Image Inpainting for Irregular Holes Using Partial Convolutions

论文笔记:Partial Convolutions局部卷积——Image Inpainting for Irregular Holes Using Partial Convolutions

综述

论文题目:《Image Inpainting for Irregular Holes Using Partial Convolutions》
会议时间:European Conference on Computer Vision 2018 (ECCV 18)

论文地址:https://openaccess.thecvf.com/content_ECCV_2018/papers/Guilin_Liu_Image_Inpainting_for_ECCV_2018_paper.pdf

源码地址:https://github.com/NVIDIA/partialconv

针对领域:图像修补

方法

  特征更新规则:不规则图形的卷积(图像非矩形)过程中,假设 W W W表示卷积核权重参数, b b b表示相应的偏置参数, X X X表示当前卷积核对应滑动窗口中的像素点特征, M M M表示像素点对应的二元掩模(利用 M M M确定不规则图形),经过局部卷积操作后的特征数据可以表示为:
x ′ = { W T ( X ⊙ M ) s u m ( 1 ) s u m ( M ) + b i f s u m ( M ) > 0 0 otherwise \begin{aligned} x'= \left\{\begin{matrix} &W^T(X\odot M)\frac{sum(1)}{sum(M)}+b\quad &if\quad sum(M)>0\\ &0\quad&\text{otherwise} \end{matrix}\right. \end{aligned} x={WT(XM)sum(M)sum(1)+b0ifsum(M)>0otherwise
其中 ⊙ \odot 表示点乘操作, s u m ( 1 ) sum(1) sum(1)表示滑动窗口对应的像素点数量,即卷积核大小。如公式所示,经过局部卷积得到的特征数据只与未屏蔽的输入数据有关,即只与掩模图 M M M上数据为 1 1 1对应位置的像素点有关。其中 s u m ( 1 ) s u m ( M ) \frac{sum(1)}{sum(M)} sum(M)sum(1)表示缩放因子,利用适当的缩放来调整有效输入的变化量,当一个像素点为有效像素点(即 M M M对应位置上为 1 1 1),但周围的有效像素点比较少时( M M M对应位置上为 0 0 0),经过卷积后该位置上的数会比较小(卷积是局部像素点与权重乘积再相加,如果周围有效像素点比较少,即0比较多,则该像素点卷积后的数比较小),此时该因子会比较大,通过放大卷积后的数据来平衡特征大小。

  掩模更新规则:如果当前像素点对应的滑动窗口中至少包含一个有效像素点,则将该位置标记为有效,即该像素点处的 M M M设置为 1 1 1,这有点类似于图像处理中的膨胀运算:
m ′ = { 1 , i f s u m ( M ) > 0 0 , otherwise \begin{aligned} m'= \left\{\begin{matrix} 1,\quad &if\quad sum(M)>0\\ 0,\quad&\text{otherwise} \end{matrix}\right. \end{aligned} m={1,0,ifsum(M)>0otherwise
该运算可以嵌入到前向传播过程中,只要输入的图像包含有效像素点,则经过足够多的部分卷积运算,掩模图 M M M上的数据最终会全部更新为1。

代码实现

以二维卷积为例

class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False
        # 设置为True表示同时返回卷积后的特征和更新后的掩模M
        self.return_mask = True
        # 这里的继承包括继承父类nn.Conv2d中,__init__函数里定义的所有属性,包括前向传播用到的偏置项
        super(PartialConv2d, self).__init__(*args, **kwargs)
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0],
                                                 self.kernel_size[1])
        else:
            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1])
        # 滑动窗口的尺寸,即长乘宽,对应论文中的sum(1)
        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \
                             self.weight_maskUpdater.shape[3]

        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            # 以下操作不产生梯度
            with torch.no_grad():
                # 保证mask卷积参数与输入图像的数据类型一致
                if self.weight_maskUpdater.type() != input.type():
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
                # 如果不存在掩模,则创建一个全一数组当做掩模图
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2],
                                          input.data.shape[3]).to(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
                else:
                    mask = mask_in
                # 先对掩模做卷积,得到论文中的sum(M)
                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride,
                                            padding=self.padding, dilation=self.dilation, groups=1)
                # 这里得到掩模比率,用于调整有效输入的特征大小,相当于论文中的sum(1)/sum(M)
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                # 通过限制最大为1,最小为0,从而得到论文中的m'
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # 更新后的掩模图与之前得到的比率相乘,使sum(M)<0时的比率变为0,用于后续和卷积后的图像做点乘,得到x'
                # 即对应公式(1)中,sum(M)<0的像素点经过局部卷积后像素点依然是0
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
        # 这里执行PartialConv2d父类的前向传播操作,而父类为nn.Conv2d,因此相当于直接做一个卷积运算
        # 如果存在掩模M,则先让输入图像与掩模M做点乘,之后做卷积运算,否则直接让输入图像做卷积运算
        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
        # 因为这个类继承nn.Conv2d,因此这里的self.bias是nn.Conv2d中自带的偏置项
        if self.bias is not None:
            # 如果该卷积操作存在偏置项的话,则先让卷积操作减去偏置项,之后再与掩模比率相乘,最后再加上偏置项
            # 对应论文中公式1,先让权重乘以输入数据,之后乘以掩模比率,最后加上偏置项
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            # 再让输出的数据与更新后的掩模图相乘(相当于再减去sum(M)<0处的偏置数据)
            output = torch.mul(output, self.update_mask)
        else:
            # 如果没有偏置项的话,则直接做点乘即可
            output = torch.mul(raw_out, self.mask_ratio)
        # 返回输出数据和更新后的掩模图
        if self.return_mask:
            return output, self.update_mask
        else:
            return output

注:以上内容仅是笔者的个人观点,若有错误,欢迎指正。

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
C语言中没有STL库,也没有类似于C++中的std::partial_sort的内置函数。不过,我们可以使用快速排序算法来实现类似的功能。 快速排序算法的基本思想是选择一个基准元素,将数组分成两部分,一部分小于等于基准元素,另一部分大于基准元素。然后递归地对两部分分别进行快速排序,最终得到一个有序数组。 实现std::partial_sort的步骤如下: 1. 取数组中前k个元素作为基准元素,对其进行快速排序。 2. 遍历数组中第k个元素到最后一个元素,并将其与基准元素进行比较,如果小于等于基准元素,则将其插入到前k个元素中,并对前k个元素进行快速排序。 3. 最终得到前k个元素是数组中最小的k个元素,但它们并不一定是有序的。可以使用任意排序算法对其进行排序。 下面是一个示例代码: ```c #include <stdio.h> void quick_sort(int arr[], int left, int right) { if (left >= right) return; int i = left, j = right; int pivot = arr[left]; while (i < j) { while (i < j && arr[j] > pivot) j--; if (i < j) arr[i++] = arr[j]; while (i < j && arr[i] <= pivot) i++; if (i < j) arr[j--] = arr[i]; } arr[i] = pivot; quick_sort(arr, left, i - 1); quick_sort(arr, i + 1, right); } void partial_sort(int arr[], int n, int k) { quick_sort(arr, 0, k - 1); for (int i = k; i < n; i++) { if (arr[i] <= arr[k - 1]) { arr[k - 1] = arr[i]; quick_sort(arr, 0, k - 1); } } quick_sort(arr, 0, k - 1); } int main() { int arr[] = {1, 3, 2, 5, 4}; int n = sizeof(arr) / sizeof(arr[0]); int k = 3; partial_sort(arr, n, k); for (int i = 0; i < k; i++) { printf("%d ", arr[i]); } printf("\n"); return 0; } ``` 该代码实现了类似于std::partial_sort的功能,输出结果为1 2 3,即数组中最小的3个元素。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值