Interactive Monte Carlo denoising using affinity of neural features & 关于torch.unfold的思考

Interactive Monte Carlo denoising using affinity of neural features & 关于torch.unfold的思考

在鹅厂实习摸鱼的日子里,遇到了一个有意思的复现问题。Interactive Monte Carlo denoising using affinity of neural features (SIGGRAPH 2021) 这篇文章提到了一种空洞卷积算法。

问题描述

在这里插入图片描述
结构图中前半部分为Unet结构,后半部分为作者设计的算法。后者由多层dilated conv和最后的warp previous output两部分组成。

作者取用Unet的输出参数作为卷积的权重实现了dilated conv,具体选取参数为:
( f x y t k , a x y t k , c x y t k ) , b x y t , λ x y t = UNet ⁡ ( I n p u t )             ( 1 ) \left(\mathbf{f}_{x y t}^{k}, a_{x y t}^{k}, c_{x y t}^{k}\right), b_{x y t}, \lambda_{x y t}=\operatorname{UNet}(Input) \ \ \ \ \ \ \ \ \ \ \ (1) (fxytk,axytk,cxytk),bxyt,λxyt=UNet(Input)           (1)

x , y , t x,y,t x,y,t 确定实时渲染中每幅图像的像素坐标,k为卷积层的序数,令 f x y t k \mathbf{f}_{x y t}^{k} fxytk 为输出特征, a x y t k , c x y t k , b x y t , λ x y t a_{x y t}^{k}, c_{x y t}^{k}, b_{x y t}, \lambda_{x y t} axytk,cxytk,bxyt,λxyt 为从UNet的输出得到的亲和参数,则可以定义一种权重 w x y u v t k w_{x y u v t}^{k} wxyuvtk 的计算方式:

w x y u v t k = { c x y t k  if  x = u  and  y = v exp ⁡ ( − a x y t k ∥ f x y t k − f u v t k ∥ 2 2 )  otherwise              ( 2 ) w_{x y u v t}^{k}= \begin{cases}c_{x y t}^{k} & \text { if } x=u \text { and } y=v \\ \exp \left(-a_{x y t}^{k}\left\|\mathbf{f}_{x y t}^{k}-\mathbf{f}_{u v t}^{k}\right\|_{2}^{2}\right) & \text { otherwise }\end{cases} \ \ \ \ \ \ \ \ \ \ \ (2) wxyuvtk={cxytkexp(axytkfxytkfuvtk22) if x=u and y=v otherwise            (2)

这里 c x y t k c_{x y t}^{k} cxytk充当替换矩阵对角线元素的作用,如果没有进行替换操作则对角线元素恒为常量1。最后实验的时候,去掉替换 c x y t k c_{x y t}^{k} cxytk这一步骤影响不是很大,当对角线元素恒为常量时,周围像素可能会根据常量来自适应调整取值范围。

w x y u v t k w_{x y u v t}^{k} wxyuvtk 作为卷积权重, ϵ \epsilon ϵ 为极小的扰动,空洞卷积后得到图像 L x y t ( k ) \mathbf{L}_{x y t}^{(k)} Lxyt(k) 为:

L x y t ( k ) = ∑ u , v w x y u v k L u v t ( k − 1 ) ϵ + ∑ u , v w x y u v k             ( 3 ) \mathbf{L}_{x y t}^{(k)}=\frac{\sum_{u, v} w_{x y u v}^{k} \mathbf{L}_{u v t}^{(k-1)}}{\epsilon+\sum_{u, v} w_{x y u v}^{k}} \ \ \ \ \ \ \ \ \ \ \ (3) Lxyt(k)=ϵ+u,vwxyuvku,vwxyuvkLuvt(k1)           (3)

Warp previous output的计算流程为先计算出亲和权重 ω x y u v t \omega_{x y u v t} ωxyuvt ,再根据亲和权重得到warp previous output后的最终图像 O x y t \mathbf{O}_{x y t} Oxyt
ω x y u v t = exp ⁡ ( − b x y t ∥ f x y t K − W t f u v , t − 1 K ∥ 2 2 )             ( 4 ) \omega_{x y u v t}=\exp \left(-b_{x y t}\left\|\mathbf{f}_{x y t}^{K}-\mathcal{W}_{t} \mathbf{f}_{u v, t-1}^{K}\right\|_{2}^{2}\right)\ \ \ \ \ \ \ \ \ \ \ (4) ωxyuvt=exp(bxytfxytKWtfuv,t1K22)           (4)

O x y t = ∑ u , v w x y u v K L u v t ( K − 1 ) + ∑ u ′ , v ′ ω x y u ′ v ′ W t O u ′ v ′ , t − 1 ϵ + ∑ u , v w x y u v K + ∑ u ′ , v ′ ω x y u ′ v ′             ( 5 ) \mathbf{O}_{x y t}=\frac{\sum_{u, v} w_{x y u v}^{K} \mathbf{L}_{u v t}^{(K-1)}+\sum_{u^{\prime}, v^{\prime}} \omega_{x y u^{\prime} v^{\prime}} \mathcal{W}_{t} \mathbf{O}_{u^{\prime} v^{\prime}, t-1}}{\epsilon+\sum_{u, v} w_{x y u v}^{K}+\sum_{u^{\prime}, v^{\prime}} \omega_{x y u^{\prime} v^{\prime}}}\ \ \ \ \ \ \ \ \ \ \ (5) Oxyt=ϵ+u,vwxyuvK+u,vωxyuvu,vwxyuvKLuvt(K1)+u,vωxyuvWtOuv,t1           (5)

实际上,式(2)(3)和式(4)(5)这两组的算法是完全一致的,都是由权重和输入图像循环求和后除以权重求和进行归一化。后者只是加上了warp操作。

Version 1.0 手写实现

空洞卷积与普通卷积的复杂度相同,那么作者提出的这种算法计算时间一定是可以接受的。为了更清晰地调整式(2)(3)中每一步运算,我首先尝试了使用这样一个五重循环Orz:

def dilated_conv(rx, ry, px, py, h, w, stride, lxyt, wxyuvt, f, i, a, c):
    for l1 in range(rx, h - rx):
        for l2 in range(ry, w - ry):
            for k1 in range(l1 - rx, l1 + rx + 1, stride):
                for k2 in range(l2 - ry, l2 + ry + 1, stride):
                    # print(l1,l2,k1,k2)
                    if l1 == k1 and l2 == k2:
                        wxyuvt[i][..., l1, l2, k1 - (l1 - rx), k2 - (l2 - ry)] = c[..., l1 - rx, l2 - ry].unsqueeze(1)
                    else:
                        wxyuvt[i][..., l1 - rx, l2 - ry, k1 - (l1 - rx), k2 - (l2 - ry)] = (
                                -a[..., l1 - rx, l2 - ry] * torch.exp(
                            torch.sum((f[..., l1, l2] - f[..., k1, k2]) ** 2, dim=1))).unsqueeze(1)
            if i != 0:
                lxyt[i][..., l1 - px, l2 - py] = torch.sum(
                    lxyt[i - 1][..., l1 - px:l1 + py + 1, l2 - py:l2 + py + 1] * wxyuvt[i][:, :, l1 - px, l2 - py, ...],
                    dim=(2, 3)) / (1e-10 + torch.sum(wxyuvt[i][:, :, l1 - rx, l2 - ry, ...], dim=(2, 3)))

核心想法是尽可能地整块操作数组。矩阵操作每往下指派一维,就要多写一个for循环,所以我尽可能多地用了切片,尽量减小向下的划分。但后来实测这样的计算根本无法进行···尝试了numba库加速,但是也还是很慢,一个epoch的计算需要几个小时的时间开销,于是就有了利用工具函数的想法。

Version 2.0 Pytorch实现

计算开销过大是因为没有合理地利用gpu进行加速。如果能使用一些pytorch的内置函数替代自己手写的算法的话,训练速度应该会提升不少。查阅了一些资料后,我使用了torch.unfold来完成目标。

式(2)(3)的实现:

stride = (2 ** i)
px = kernel_height // 2
py = kernel_width // 2
rx = ((kernel_height - 1) * stride + 1) // 2
ry = ((kernel_width - 1) * stride + 1) // 2
f_ori = kernels_param[:, (i * 10):(i * 10 + 8), :, :]
a = torch.square(kernels_param[:, (i * 10 + 8), :, :])
a = a.unsqueeze(-1).permute(0, 3, 1, 2)

ta += a
c = self.sigmoid(kernels_param[:, (i * 10 + 9), :, :])
# print('a',a.shape)
# print('c',c.shape)

# fxyt expand
fxyt[i] = (f_ori.unsqueeze(dim=4)).repeat(1, 1, 1, 1, kernel_l)
# fxyt[i] = fxyt[i].reshape([bs, -1, kernel_l, h, w])
fxyt[i] = fxyt[i].permute(0, 1, 4, 2, 3)

# fuvt expand
unfold = torch.nn.Unfold([kernel_height, kernel_width], stride=1, dilation=stride, padding=rx)
fuvt[i] = unfold(f_ori)
# fuvt[i] = fuvt[i].view(bs, -1, kernel_height, kernel_width, h, w)
fuvt[i] = fuvt[i].view(bs, 8, kernel_l, h, w)

# eq(4)
wxyuvt[i] = torch.exp(-a * torch.sum((fxyt[i] - fuvt[i]) ** 2, dim=1))
wxyuvt[i] = wxyuvt[i].unsqueeze(-1).permute(0, 4, 2, 3, 1)

c = c.view(-1)
index = (
    torch.LongTensor([(0 if i < h * w else 1) for i in range(h * w * bs)]),
    torch.LongTensor([0 for i in range(h * w * bs)]),
    torch.LongTensor([i % h for i in range(h * w * bs)]),
    torch.LongTensor([i % w for i in range(h * w * bs)]),
    torch.LongTensor([(kernel_l // 2) for i in range(h * w * bs)])
)
wxyuvt[i] = wxyuvt[i].index_put(index, c)

式(4)(5)的实现:

 lxyt[i] = unfold(lxyt[i - 1])
 lxyt[i] = lxyt[i].view(bs, C, kernel_l, h, w)
 # print(lxyt[i].shape)
 lxyt[i] = lxyt[i].permute(0, 1, 3, 4, 2)
 # print('self.previous_fk',self.previous_fk.shape)
 # print('mv',mv.shape)
 warp_fuvt = unfold(self.warper2d(self.previous_fk, mv))
 # warp_fuvt = warp_fuvt.view(bs, -1, kernel_height, kernel_width, h, w)
 warp_fuvt = warp_fuvt.view(bs, 8, kernel_l, h, w)
 # print(warp_fuvt.shape)
 # print(warp_fuvt.shape)
 # print(fxyt[i].shape)
 # print(torch.sum((fxyt[i] - warp_fuvt) ** 2, dim=1).shape)
 w_xyuvt = torch.exp(-b * torch.sum((fxyt[i] - warp_fuvt) ** 2, dim=1))
 # w_xyuvt = w_xyuvt.reshape(bs, -1, h, w, kernel_l)
 # print(w_xyuvt.shape)
 w_xyuvt = w_xyuvt.unsqueeze(-1).permute(0, 4, 2, 3, 1)
 # print('w_xyuvt',w_xyuvt.shape)
 # print('lxyt[i]',lxyt[i].shape)

 previous_o = unfold(self.previous_o)
 previous_o = previous_o.view(bs, C, kernel_l, h, w)
 # print(previous_o.shape)
 previous_o = previous_o.permute(0, 1, 3, 4, 2)
 lxyt[i] = (torch.sum(lxyt[i] * wxyuvt[i], dim=4) + torch.sum(previous_o * w_xyuvt,
                                                              dim=4)) / (
                   1e-10 + torch.sum(wxyuvt[i], dim=4) + torch.sum(w_xyuvt, dim=4))

在pytorch中没有直接对空洞卷积权重进行编辑的函数,但我们可以退一步对于式(2)(3)的构成部分分别进行拆解。torch.unfold提供了一个滑窗功能,能够从指定向量中取出周围相邻的元素作为新向量的某一维元素,这些元素被堆叠起来,也就是相当于对于原向量进行了滑窗提取元素,但是是降维的。具体来说,unfold函数根据指定的kernel size滑窗提取向量,再将向量flatten,按照channel的维度依次排序。最终得到的向量维度构成方式其实为:[batch_size, -1, kernel_size,kernel_size,height,width]。在unfold操作后,我们需要使用torch.permute函数将向量的维度恢复正常。这样子,一个unfold和一个permute函数,相当于完成了卷积的第一步滑窗。

对于式(2)中的 c x y t k c_{x y t}^{k} cxytk计算操作,如果按照对角线元素逐一遍历,也需要使用bs,c,h,w次查询,带来较大的计算开销。所以使用index_put来进行规则替换。以下则是对角线元素的坐标规则化表示。

torch.LongTensor([(0 if i < h * w else 1) for i in range(h * w * bs)]),
torch.LongTensor([0 for i in range(h * w * bs)]),
torch.LongTensor([i % h for i in range(h * w * bs)]),
torch.LongTensor([i % w for i in range(h * w bs)]),
torch.LongTensor([(kernel_l // 2) for i in range(h * w * bs)]))

在计算式(2)下半部分时,对于 f u v t k \mathbf{f}_{u v t}^{k} fuvtk可以使用上述所说的unfold函数进行滑窗提取, f x y t k \mathbf{f}_{x y t}^{k} fxytk则可以使用torch.repeat将其原向量扩展为与 f u v t k \mathbf{f}_{u v t}^{k} fuvtk相同维度的向量。从复杂度的角度来考虑,这样的操作其实是使用了空间开销换时间开销。unfold滑窗+repeat将式(2)的作差运算变成了整体运算,去掉了kernel size对应的两层循环。

w_xyuvt = torch.exp(-b * torch.sum((fxyt[i] - warp_fuvt) ** 2, dim=1))

在unfold函数的帮助下,训练时间最终缩短到了40min/epoch。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值