Restormer

## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)

        return x

OverlapPatchEmbed 将输入图像传递给一个3x3的卷积层,得到一个嵌入结果。由于其步长为1且有填充,该卷积操作保持了输入图像的空间尺寸。这与许多其他的patch embedding方法不同,这些方法通常使用较大的卷积核和步长来直接将图像分割成不重取的patch。这里的方法允许邻近的patch有重叠,因此得名 “OverlapPatchEmbed”

class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)

Downsample 模块首先使用卷积层将输入特征图的通道数减半,然后使用像素反洗牌操作进一步降采样特征图。这两步操作联合起来为输入特征图提供了降采样效果
在这里插入图片描述
这是Pytorch的官方文档,可以看到,Downsample模块先用二维卷积层nn.Conv2d把输入通道减半,然后用nn.PixelUnshuffle将通道数x4,长/2,宽/2,所以总体上该模块是将通道数x2,长/2,宽/2。

class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()

        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)

在这里插入图片描述
Transformer Block不改变输出维度
在这里插入图片描述
以下是Restormer的实现代码,与上图完全一致。

##---------- Restormer -----------------------
class Restormer(nn.Module):
    def __init__(self, 
        inp_channels=3, 
        out_channels=3, 
        dim = 48,
        num_blocks = [4,6,6,8], 
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        dual_pixel_task = False        ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
    ):

        super(Restormer, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)#3x3卷积层

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor\
                                                               , bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,\
                                                               bias=bias,LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, \
                                                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, \
                                                       bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        
        self.up4_3 = Upsample(int(dim*2**3)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,\
                                                                bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])


        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, \
                                                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias,\
                                                                LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, \
                                                           LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
        
        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias)
        ###########################
            
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):

        inp_enc_level1 = self.patch_embed(inp_img)#F0
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        
        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3) 

        inp_enc_level4 = self.down3_4(out_enc_level3)        
        latent = self.latent(inp_enc_level4) 
                        
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3) 

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2) 

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        
        out_dec_level1 = self.refinement(out_dec_level1)

        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            out_dec_level1 = self.output(out_dec_level1) + inp_img


        return out_dec_level1
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: restormer代码是指一种用于机器学习和数据分析的开源数据处理框架。该框架的主要目标是简化数据处理流程,提供高效、可扩展和易于使用的工具。 在restormer代码中,使用了多种技术和算法来处理不同类型的数据。它提供了一套强大的工具和库,用于数据预处理、特征工程、模型训练和评估等任务。 restormer代码采用了模块化的设计思路,可以根据需要选择不同的组件来构建自己的数据处理流程。它提供了一些常用的数据处理函数和工具,如数据清洗、缺失值填充、特征选择、特征转换等,同时还支持自定义函数和组件,方便用户根据自己的需求进行扩展和定制化。 该框架还提供了一些高级功能,如数据集划分、交叉验证、模型融合等,可以帮助用户更好地进行数据分析和建模。同时,它还提供了一些可视化工具,方便用户对数据和模型进行观察和分析。 总之,restormer代码是一种功能强大、灵活可扩展的数据处理框架,可以帮助用户快速、高效地进行机器学习和数据分析任务。它的设计思路和丰富的功能使得用户可以根据自己的需求进行数据预处理和模型构建,提高工作效率和数据科学能力。 ### 回答2: Restormer是一种不断进化优化的代码,它的主要目标是提供高性能和高效率的编程解决方案。Restormer代码的核心原理是通过持续迭代和适应性改进来改进程序的性能和效率。 Restormer的核心思想是在程序运行期间不断收集性能指标和数据,然后使用这些信息来自动调整代码以达到更好的性能。它通过使用自适应技术来调整代码的各个方面,如数据结构、算法和并行处理等,以进行优化。这种自适应技术可以根据实际情况动态调整各个代码部分的执行顺序和参数设置,以获得最佳的性能结果。 Restormer代码的一个关键特点是其自学习的能力。它可以分析运行时的性能指标和数据,自动学习什么样的代码变化可以提高性能,并根据这些学习结果来优化代码。通过不断地学习和进化,Restormer代码可以在不断变化的运行环境和需求下持续提供最佳的性能。 另外,Restormer代码也注重灵活性和可扩展性。它可以适应各种不同的编程语言和框架,并且可以应用于任何规模的项目。开发人员可以根据自己的需求和特定的应用场景来使用Restormer代码。 总的来说,Restormer代码通过持续迭代和适应性改进来提供高性能和高效率的解决方案。它的自适应技术和自学习能力可以根据实际情况动态地调整代码,以实现最佳性能结果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

果壳小旋子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值