CVPR LIIF超分辨率代码阅读

Learning Continuous Image Representation with Local Implicit Image Function


paper
code


abstract

物理世界以连续的方式呈现视觉图像,但计算机以离散2D像素数组的方式存储和显示图像。此文学习图像的连续表示,使用局部隐式图像函数(Local Implicit Image Function,LIIF)将图像坐标和坐标周围的2D深度特征作为输入,预测输出给定坐标下的RGB值。通过自监督超分辨率任务来训练一个编码器和LIIF表示来生成像素图像的连续表示,可以做到任意倍数的分辨率,甚至可以推算不在训练任务中的30倍以上超分。通过将图像模型化为一个在连续域中的函数,可以恢复和生成任意分辨率的图像。隐式函数的思想是将一个对象表示为一个函数,将坐标映射到相应的信号(如3D对象表面的符号距离,图像中的RGB值)。神经隐式函数采用深度神经网络参数化。为了跨实例共享知识,而不是为每个对象拟合单独的隐式函数,提出了基于编码器的方法来预测每个对象的潜在编码。然后隐式函数由所有对象共享,同时它将潜在代码作为额外的输入。
LIIF

Local Implicit Image Function

在LIIF表示中,每个连续图像 I ( i ) I^{(i)} I(i)由二维特征映射 M ( i ) ∈ R H × W × D M^{(i)}∈\R^{H \times W \times D} M(i)RH×W×D表示。 一个神经隐式函数 f θ f_θ fθ(以 θ θ θ为其参数)被所有图像共享,它被参数化为 M L P MLP MLP并采取 s = f ( z , x ) s = f(z,x) s=f(z,x)(简便省略 θ θ θ)形式,其中 z z z是一个向量, x ∈ X x \in X xX是连续图像域中的二维坐标, s ∈ S s \in S sS是预测信号(即RGB值)。

对于定义的 f f f,每个向量 z z z都可以看作是表示函数 f ( z , ⋅ ) : X → S f(z,·):X→S f(z,):XS f ( z , ⋅ ) f(z,·) f(z,)可以看作是一个连续的图像,即映射坐标到RGB值的函数。假设 M ( i ) M^{(i)} M(i) H × W H \times W H×W特征向量(称为隐码latent codes)均匀分布在 I ( i ) I^{(i)} I(i)的连续图像域的2D空间中,并为它们中的每一个分配一个2D坐标。

对于图像 I ( i ) I^{(i)} I(i),坐标 x q x_q xq处的RGB值定义为 I ( i ) ( x q ) = f ( z ∗ , x q − v ∗ ) I^{(i)}(x_q) = f(z^*,x_q-v^*) I(i)(xq)=f(z,xqv),其中 z ∗ z^∗ z M ( i ) M^{(i)} M(i)中与 x q x_q xq最近的(欧几里德距离)隐码, v ∗ v^∗ v是图像域中潜码 z ∗ z^∗ z的坐标。 例如 z 11 ∗ z^∗_{11} z11是当前定义中 x q x_q xq z ∗ z^∗ z,而 v ∗ v^∗ v被定义为 z 11 ∗ z^∗_{11} z11的坐标。
liif
在所有图像共享的隐式函数 f f f下,连续图像由二维特征映射 M ( i ) ∈ R H × W × D M^{(i)} \in \R^{H \times W \times D} M(i)RH×W×D表示,该特征映射被看作是在2D域中均匀分布的 H × W H×W H×W隐码。 在 M ( i ) M^{(i)} M(i)中的每个潜在码 z z z表示连续图像的局部部分,负责预测与它最近的坐标集的信号。

从图像得到归一化坐标值和RGB值

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
    
coord = make_coord((h, w)) #h,w为SR目标的高宽

def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])   #(h*w,2)--(h*w,[x,y])
    rgb = img.view(3, -1).permute(1, 0)  #(h*w,3)--(h*w,[R,G,B])
    return coord, rgb

Feature unfolding

为了丰富隐码包含的信息,对特征 M ( i ) M^{(i)} M(i)展开得到 M ^ ( i ) {\hat M^{(i)}} M^(i) M ^ ( i ) {\hat M^{(i)}} M^(i)是在 M ( i ) M^{(i)} M(i) 3 × 3 3 \times 3 3×3相邻隐码的合并。
M ^ j k ( i ) = C o n c a t ( { M j + l , k + m ( i ) } l , m ∈ { − 1 , 0 , 1 } ) {\hat M^{(i)}_{jk}} =Concat(\{M^{(i)}_{j+l,k+m}\}_{l,m\in\{-1,0,1\}}) M^jk(i)=Concat({Mj+l,k+m(i)}l,m{1,0,1})
C o n c a t Concat Concat指的是一组向量的连接时, M ( i ) M^{(i)} M(i)在其边界外被零向量填充。
[ N , C , L R H , L R W ] [N,C,LR_H,LR_W] [N,C,LRH,LRW] f e a t feat feat被以下 u n f o l d unfold unfold后变为 [ N , C ∗ 3 ∗ 3 , L R H , L R W ] [N,C*3*3,LR_H,LR_W] [N,C33,LRH,LRW]

feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

Local ensemble

s = f ( z , x ) s = f(z,x) s=f(z,x)是一个不连续预测,由于 x q x_q xq的信号预测是通过查询 M ( i ) M^{(i)} M(i)中最近的隐码 z ∗ z^∗ z完成的,所以当 x q x_q xq在图像域中移动时, z ∗ z^∗ z会突然从一个隐码切换到另一个隐码。在 z ∗ z^∗ z选择切换的那些坐标周围,两个无限接近坐标的信号将从不同的隐码中预测出来,只要学习的隐式函数 f f f不是完美的,在 z ∗ z^∗ z选择切换的边界处没出现不连续的图形。为了解决这个问题,使用局部集成技术,扩大每个隐码的表示
I ( i ) ( x q ) = ∑ t ∈ { 00 , 01 , 10 , 11 } S t S f ( z t ∗ , x q − v t ∗ ) I^{(i)}(x_q) =\sum_{t\in\{00,01,10,11\}} \frac {S_t}{S} f(z^*_t,x_q-v^*_t) I(i)(xq)=t{00,01,10,11}SStf(zt,xqvt)
z t ∗ ( t ∈ { 00 , 01 , 10 , 11 } ) z^*_t(t\in\{00,01,10,11\}) zt(t{00,01,10,11})指左上、右上,左下,右下子空间中最近的隐码, v t ∗ v^*_t vt z t ∗ z^*_t zt的坐标, S t S_t St x q x_q xq v t ′ ∗ v^*_{t'} vt v t ′ ∗ v^*_{t'} vt v t ∗ v^*_{t} vt的对角,如00对11,10对01)之间的矩形面积。权重由 S = ∑ t S t S=\sum_tS_t S=tSt归一化。特征图 M ( i ) M^{(i)} M(i)在边界外是镜像填充的,因此这也适用于边界附近的坐标。

这是为了让由隐码表示的局部图像块与其相邻块重叠,使得在每个坐标处有四个隐码用于独立预测信号。然后,这四个预测通过用归一化置信度投票来合并,归一化置信度与查询点和其最近的隐码对角对应点之间的矩形面积成比例,因此当查询坐标更近时,置信度变得更高。通过这种投票,它在 z ∗ z^* z转换坐标(即图中的虚线)处实现了连续过渡。

vx_lst = [-1, 1]
vy_lst = [-1, 1]
eps_shift = 1e-6

rx = 2 / feat.shape[-2] / 2  #2/H/2
ry = 2 / feat.shape[-1] / 2  #2/W/2

feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

preds = []
areas = []
for vx in vx_lst:
    for vy in vy_lst:
        coord_ = coord.clone()#[N,SR_H*SR_W,2]
        coord_[:, :, 0] += vx * rx + eps_shift
        coord_[:, :, 1] += vy * ry + eps_shift
        coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

        q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
        q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

        q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
        q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

        rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
        rel_coord[:, :, 0] *= feat.shape[-2]
        rel_coord[:, :, 1] *= feat.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

        if self.cell_decode:
            rel_cell = cell.clone()
            rel_cell[:, :, 0] *= feat.shape[-2]
            rel_cell[:, :, 1] *= feat.shape[-1]
            inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

        bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
        #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
        pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
        preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

        area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
        areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
if self.local_ensemble:
    t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
    t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])

Cell decoding

为了LIIF能够表示基于像素形式的任意分辨率呈现,假设给定了所需分辨率,一种简单方法是查询连续表示 I ( ∗ ) I^{(*)} I()中像素中心坐标处的RGB值,但因为查询像素的预测RGB值与其大小无关,其像素区域中的信息除了中心值都被丢弃,可能不是最佳的。
s = f c e l l ( z , [ x , c ] ) s=f_{cell}(z,[x,c]) s=fcell(z,[x,c])
c = [ c h , c w ] c=[c_h,c_w] c=[ch,cw]包含指定查询像素的高度和宽度两个值, [ x , c ] [x,c] [x,c]是值 x x x c c c的连接(concatenation), c c c是附加输入。
f c e l l ( z , [ x , c ] ) f_{cell}(z,[x,c]) fcell(z,[x,c])能理解为使用形状 c c c渲染以坐标 x x x为中心的像素的RGB值。对于 64 × 64 64\times64 64×64的分辨率, c c c是图像宽度的 1 / 64 1/64 1/64。逻辑上,当 c → 0 c→0 c0时, f c e l l ( z , x ) = f c e l l ( z , [ x , c ] ) f_{cell}(z,x) =f_{cell}(z,[x,c]) fcell(z,x)=fcell(z,[x,c]),即连续图像可以看作像素无限小的图像。

cell = torch.ones_like(coord) #[SR_H*SR_W,2] [1*2/SR_H,1*2/SR_W]
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w 

if self.cell_decode:
     rel_cell = cell.clone()
     rel_cell[:, :, 0] *= feat.shape[-2]
     rel_cell[:, :, 1] *= feat.shape[-1]
     inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

LIIF class 完全代码

class LIIF(nn.Module):
    def __init__(self, encoder_spec, imnet_spec=None,
                 local_ensemble=True, feat_unfold=True, cell_decode=True):
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        self.encoder = models.make(encoder_spec)

        #print("self.encoder.out_dim",self.encoder.out_dim)
        if imnet_spec is not None:
            imnet_in_dim = self.encoder.out_dim     #64
            if self.feat_unfold:
                imnet_in_dim *= 9
            imnet_in_dim += 2 # attach coord 指定查询像素的坐标 [x,y]
            if self.cell_decode:
                imnet_in_dim += 2 #[Cell_h, Cell_w]指定查询像素的高度和宽度的两个值
            self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
        else:
            self.imnet = None

    def gen_feat(self, inp):
        self.feat = self.encoder(inp)
        return self.feat

    def query_rgb(self, coord, cell=None):
        #coord [N,SR_H*SR_*W,2]
        #cell [N,SR_H*SR_*W,2]
        feat = self.feat #[N,C,LR_H,LR_W]

        if self.imnet is None:
            ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
            ret = ret[:, :, 0, :].permute(0, 2, 1)
            return ret

        if self.feat_unfold:
            # [N,C*3*3,H,W]
            feat = F.unfold(feat, 3, padding=1).view(feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # field radius (global: [-1, 1])
        rx = 2 / feat.shape[-2] / 2  #2/H/2
        ry = 2 / feat.shape[-1] / 2  #2/W/2

        feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() #[LR_H,LR_W,2]
        feat_coord = feat_coord.permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])#[N,2,LR_H,LR_W]

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()#[N,SR_H*SR_W,2]
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

                q_feat = F.grid_sample(feat, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,C*9,1,SR_H*SR_W]
                q_feat = q_feat[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,C*9]

                q_coord = F.grid_sample(feat_coord, coord_.flip(-1).unsqueeze(1),mode='nearest', align_corners=False)#[N,2,1,SR_H*SR_W]
                q_coord = q_coord[:, :, 0, :].permute(0, 2, 1)#[N,SR_H*SR_W,2]

                rel_coord = coord - q_coord #[N,SR_H*SR_W,2]
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1) #[N,SR_H*SR_W,C*9+2]

                if self.cell_decode:
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1) #[N,SR_H*SR_W,C*9+2+2]

                bs, q = coord.shape[:2] #bs=N q=SR_H*SR_W
                #[N*SR_H*SR_W,C*9+2+2] --> [N*SR_H*SR_W,3]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) #[N,SR_H*SR_W,3]
                preds.append(pred) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9) #[[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W],[N,SR_H*SR_W]]


        tot_area = torch.stack(areas).sum(dim=0) #[N,SR_H*SR_W]
        if self.local_ensemble:
            t = areas[0]; areas[0] = areas[3]; areas[3] = t #swap(areas[0],areas[3])
            t = areas[1]; areas[1] = areas[2]; areas[2] = t #swap(areas[1],areas[2])
        ret = 0
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        return ret

    def forward(self, inp, coord, cell):
        self.gen_feat(inp)
        return self.query_rgb(coord, cell)
  • 9
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值