Progressive learning

Restormer是一种用于高分辨率图像修复的高效Transformer模型,采用渐进式学习策略处理全局图像统计。网络结构包括Gated-DconvFeed-ForwardNetwork和Multi-DConvHeadTransposedSelf-Attention,实现深度和宽度的平衡,优化参数和FLOPs预算。实验表明,在相似的计算预算下,更深且更窄的网络提供更高的精度。
摘要由CSDN通过智能技术生成

learn from Restormer. Restormer: Efficient Transformer for High-Resolution Image Restoration | IEEE Conference Publication | IEEE Xplore

Progressive learning

在小的crop patches上训练Transformer模型可能无法对全局图像统计进行编码,从而在测试时对全分辨率图像的效果不佳。作者提出渐进式学习,其中网络在早期的时代在较小的图像patch上进行训练,在后期的训练时代在逐渐变大的patch上进行训练。patch大时减小batch size。

# batch: 8
# mini_batch_sizes: [8,5,4,2,1,1]  
# iters: [92000,64000,48000,36000,36000,24000]
# gt_size: 384   # Max patch size for progressive training
# gt_sizes: [128,160,192,256,320,384]
# scale = 1.
# groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])
# groups: [92000, 156000, 204000, 240000, 276000, 300000]

  
j = ((current_iter > groups) != True).nonzero()[0]
if len(j) == 0:
    bs_j = len(groups) - 1
else:
    bs_j = j[0]

mini_gt_size = mini_gt_sizes[bs_j]
mini_batch_size = mini_batch_sizes[bs_j]


lq = train_data['lq']  # train_data为pytorch DataLoader返回的(b, c, h, w) tensor
gt = train_data['gt']

if mini_batch_size < batch_size:
    indices = random.sample(range(0, batch_size), k=mini_batch_size)
    lq = lq[indices]
    gt = gt[indices]

if mini_gt_size < gt_size:
    x0 = int((gt_size - mini_gt_size) * random.random())
    y0 = int((gt_size - mini_gt_size) * random.random())
    x1 = x0 + mini_gt_size
    y1 = y0 + mini_gt_size
    lq = lq[:, :, x0:x1, y0:y1]
    gt = gt[:, :, x0 * scale:x1 * scale, y0 * scale:y1 * scale]

Python nonzero(a): 返回数组a中非零元素的索引值tuple。

如上例中当current_iter=0,(current_iter > groups) != True 结果为[True, True, True, True, True, True],则nozero返回(array([0,1,2,3,4,5]),) ,nozero[0] = [0,1,2,3,4,5].

random.sample(sequence, k)  sequence: 可以是一个列表,元组,字符串,或集合

从序列sequence中选择元素的k长度的新列表。

random.random() :该方法返回一个0到1之间的随机浮动数。

Deeper or wider network

作者经过消融实验发现:similar parameters/FLOPs budget,深且transformer block 中dim小(即窄)的网络更精准,而宽且dim大的网络速度更快。

Restormer中的网络架构

1.Gated-Dconv Feed-Forward Network (GDFN)

class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim*ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1,bias=bias)

        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
        # depth-wise conv
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x

2.Multi-DConv Head Transposed Self-Attention (MDTA) 

class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        


    def forward(self, x):
        b,c,h,w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q,k,v = qkv.chunk(3, dim=1)   
        
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out

3. Downsample 

class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

与标准下采样方法不同,首先用conv再用pixel-unshuffle。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
基于引用和引用中的内容,Transformer可以应用于图像恢复和重建任务,如去雾和去模糊。其创新点有以下几个方面:首先,改进了Transformer的空间自注意力机制,将其替换为带有深度可分离卷积的通道自注意力;其次,在卷积前向网络中引入了gating和深度可分离卷积;最后,训练方式采用了逐渐增大输入图像尺寸的progressive learning策略。 根据引用中的内容,Transformer图像重建的网络结构主要包括以下几个部分:首先,通过3×3卷积提取低阶特征F0;然后,使用由4个阶段的Transformer构成的编码器-解码器进行上下采样,其中上采样使用pixel unshuffle,下采样使用shuffle;接下来,通过Transformer进行图像细化;最后,通过3×3卷积恢复原始通道数,并将其与原始图像的残差相加,得到重建的输出图像。 因此,Transformer图像重建的方法可以总结为:首先提取低阶特征,然后使用编码器-解码器结构进行上下采样,接着进行图像细化,最后恢复原始通道数并将其与原始图像的残差相加,得到重建的输出图像。 总结起来,Transformer在图像重建中的创新点包括改进的自注意力机制和卷积前向网络,以及采用逐渐增大输入图像尺寸的训练方式。其网络结构包括特征提取、编码器-解码器结构、图像细化和通道恢复等步骤。以上是基于所提供的引用内容给出的关于Transformer图像重建的回答。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值