Upgrading SRGAN with ConvNeXt

I have recently been studying super-resolution reconstruction, and I used ConvNeXt to make some architectural upgrades to the SRGAN network; the results are quite good.

Code

        With trained weights and a small dataset (414 MB):

                Link: https://pan.baidu.com/s/1KYyHyE5BpCTjKNuwBh52Wg?pwd=dnmd
                Extraction code: dnmd

        Code only (5.8 MB):

                Link: https://pan.baidu.com/s/18kqXsh6NnKKMNg3WHdhRUA?pwd=89ay
                Extraction code: 89ay

Video

Full SRGAN code, optimized with ConvNeXt (Bilibili)

Below is a comparison between SRGANext (my ConvNeXt-based version) and the original SRGAN.

Below are SRGAN's results at training epoch 100: the low-resolution input, the high-resolution image generated from it, and the ground-truth high-resolution image.

Below are SRGANext's results at training epoch 100: the low-resolution input, the high-resolution image generated from it, and the ground-truth high-resolution image.
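
The comparisons above are visual. To quantify them, the usual metric is PSNR between the generated image and the ground truth; here is a minimal sketch of my own (not from the original post), assuming both images are float tensors scaled to [0, 1]:

import torch

# Illustration only: PSNR in dB between a super-resolved image `sr` and its
# ground truth `hr`, both float tensors in [0, 1] with identical shapes.
def psnr(sr, hr, max_val=1.0):
    mse = torch.mean((sr - hr) ** 2)
    return 10 * torch.log10(max_val ** 2 / mse)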

The SRGAN architecture:

My SRGANext generator architecture:

My SRGANext discriminator architecture:

The block used by both the generator and the discriminator:

The SRGAN code:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    '''Two 3x3 convolutions that change neither the spatial size nor the
    channel count, so the residual connection needs no 1x1 convolution to
    match dimensions.'''
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),  # keyword arg: the second positional arg of BatchNorm2d is eps, not momentum
            nn.PReLU(),
            nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_features, momentum=0.8),
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, scale_factor=2, num_residual_blocks=16):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4, padding_mode='reflect', bias=True)
        self.prelu = nn.PReLU()
        self.residual_blocks = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residual_blocks)])
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(64)
        # Each PixelShuffle(scale_factor) trades scale_factor^2 channels for
        # spatial size; the channel math (256 -> 64) assumes scale_factor=2,
        # so the two stages give a 4x total upscale.
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=True),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=True),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
            nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4, bias=True)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.prelu(out)
        residual = out
        out = self.residual_blocks(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.upsample(out)
        return out

class Discriminator(nn.Module):
    '''VGG-style discriminator: stride-2 convolutions halve the resolution
    while the channel count grows, ending in global pooling and an MLP that
    outputs a real/fake probability.'''
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=True),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.Dropout(),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        out = self.layer(x)
        return out
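
As a quick sanity check (my own sketch, not from the original post), the generator should enlarge a low-resolution image 4x on each side, and the discriminator should map any image to a single real/fake probability:

# Illustration only: verify the 4x upscale and the discriminator output shape.
if __name__ == '__main__':
    g, d = Generator(), Discriminator()
    lr = torch.randn(1, 3, 64, 64)   # dummy low-resolution batch
    sr = g(lr)
    print(sr.shape)                  # torch.Size([1, 3, 256, 256])
    print(d(sr).shape)               # torch.Size([1, 1])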

The SRGANext code:

Note that nn.LayerNorm normalizes over the last dimension, so tensors are permuted to channels-last before normalization and permuted back afterwards.

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    '''
        [b,c,h,w] --> permute [b,h,w,c] --> LN --> permute [b,c,h,w]
    '''
    def __init__(self, dim):
        super(LayerNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)   # channels-last, so nn.LayerNorm sees dim last
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)   # back to channels-first
        return x
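
# A quick sanity check of the permute trick (my own illustration, not part of
# the original code): each spatial position is normalized across its channels
# while the tensor keeps its [B, C, H, W] layout, e.g.
#   x = torch.randn(2, 64, 8, 8)    # [B, C, H, W]
#   y = LayerNorm(64)(x)            # same shape; per-pixel mean ~0, var ~1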

class SRGANextBlock(nn.Module):
    '''ConvNeXt-style block: 7x7 depthwise conv -> LayerNorm -> 1x1 conv
    expanding to 4*dim -> GELU -> 1x1 conv back to dim, with layer scale
    and a residual connection.'''
    def __init__(self, dim, layer_scale_init_value=1e-6):
        super(SRGANextBlock, self).__init__()
        self.dconv1 = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=7, stride=1, padding=3, groups=dim)
        self.norm1 = LayerNorm(dim)
        self.pconv1 = nn.Conv2d(dim, 4 * dim, kernel_size=1, stride=1)
        self.act = nn.GELU()
        self.pconv2 = nn.Conv2d(4 * dim, dim, kernel_size=1, stride=1)
        # Layer scale: a learnable per-channel weight, initialized small.
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim,)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        shortcut = x
        x = self.dconv1(x)
        x = self.norm1(x)
        x = self.pconv1(x)
        x = self.act(x)
        x = self.pconv2(x)
        x = x.permute(0, 2, 3, 1)  # [N, C, H, W] -> [N, H, W, C]
        if self.gamma is not None:
            x = self.gamma * x     # broadcasts over the channel dimension
        x = x.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]
        x = shortcut + x
        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.stem = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)
        self.norm1 = LayerNorm(64)
        # Four stages with a 3/3/9/3 block layout (as in ConvNeXt-T); 1x1
        # convolutions widen the channels between stages instead of downsampling.
        self.block1 = nn.Sequential(*[SRGANextBlock(64) for _ in range(3)])
        self.norm2 = LayerNorm(64)
        self.pconv1 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
        self.block2 = nn.Sequential(*[SRGANextBlock(128) for _ in range(3)])
        self.norm3 = LayerNorm(128)
        self.pconv2 = nn.Conv2d(128, 256, kernel_size=1, stride=1)
        self.block3 = nn.Sequential(*[SRGANextBlock(256) for _ in range(9)])
        self.norm4 = LayerNorm(256)
        self.pconv3 = nn.Conv2d(256, 512, kernel_size=1, stride=1)
        self.block4 = nn.Sequential(*[SRGANextBlock(512) for _ in range(3)])
        # Two PixelShuffle(2) stages for a 4x total upscale:
        # 512 -> 1024 -> (shuffle) 256 -> 512 -> (shuffle) 128 channels.
        self.upsample = nn.Sequential(
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.conv1 = nn.Conv2d(128, 3, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        x = self.stem(x)
        x = self.norm1(x)
        x = self.block1(x)
        x = self.norm2(x)
        x = self.pconv1(x)
        x = self.block2(x)
        x = self.norm3(x)
        x = self.pconv2(x)
        x = self.block3(x)
        x = self.norm4(x)
        x = self.pconv3(x)
        x = self.block4(x)
        x = self.upsample(x)
        x = self.conv1(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # ConvNeXt-style downsampling backbone, again with a 3/3/9/3 layout.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=3),
            LayerNorm(96)
        )
        self.block1 = nn.Sequential(*[SRGANextBlock(96) for _ in range(3)])
        self.downsample1 = nn.Sequential(
            LayerNorm(96),
            nn.Conv2d(96, 192, kernel_size=2, stride=2)
        )
        self.block2 = nn.Sequential(*[SRGANextBlock(192) for _ in range(3)])
        self.downsample2 = nn.Sequential(
            LayerNorm(192),
            nn.Conv2d(192, 384, kernel_size=2, stride=2)
        )
        self.block3 = nn.Sequential(*[SRGANextBlock(384) for _ in range(9)])
        self.downsample3 = nn.Sequential(
            LayerNorm(384),
            nn.Conv2d(384, 768, kernel_size=2, stride=2)
        )
        self.block4 = nn.Sequential(*[SRGANextBlock(768) for _ in range(3)])
        self.endLayer = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(768, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.stem(x)
        # Each stage adds an extra outer residual around its block stack
        # (every SRGANextBlock already has its own inner residual).
        shortcut = x
        x = shortcut + self.block1(x)
        x = self.downsample1(x)
        shortcut = x
        x = shortcut + self.block2(x)
        x = self.downsample2(x)
        shortcut = x
        x = shortcut + self.block3(x)
        x = self.downsample3(x)
        shortcut = x
        x = shortcut + self.block4(x)
        x = self.endLayer(x)
        return x
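
The same shape check applies here (again my own sketch, not from the original post): SRGANext's generator also upscales 4x through its two PixelShuffle(2) stages, and the discriminator outputs one probability per image:

# Illustration only: verify SRGANext's output shapes.
if __name__ == '__main__':
    g, d = Generator(), Discriminator()
    lr = torch.randn(1, 3, 64, 64)
    sr = g(lr)
    print(sr.shape)                  # torch.Size([1, 3, 256, 256])
    print(d(sr).shape)               # torch.Size([1, 1])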
