超分之ESRGAN官方代码解读

改进1:生成器RRDBNet_arch.py.py

  • 引入了没有BN层的Residual-in-Residual Dense Block(RRDB)作为基本网络构建单元
    在这里插入图片描述
    在这里插入图片描述

1.1 RDB(Residual Dense Block)

class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features. (中间特征的通道数)
        num_grow_ch (int): Channels for each growth. (每次增长的通道数)
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # 输入通道数,输出通道数,卷积核,步长,填充
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # # initialization (初始化每个卷积层的权重参数)
        # default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        # 0.2 是残差缩放的超参数
        return x5 * 0.2 + x

查看RDB中具体的卷积操作

RDB = ResidualDenseBlock()
RDB

在这里插入图片描述
测试RDB

X =torch.rand(1, 64, 256, 256)
Y =RDB(X)
Y.shape

在这里插入图片描述

1.2 RRDB(Residual in Residual Dense Block)

class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x

查看RRDB中具体的卷积操作

RRDB = RRDB(64)
RRDB

在这里插入图片描述
测试RDB

X =torch.rand(1, 64, 256, 256)
Y =RRDB(X)
Y.shape

在这里插入图片描述

1.3 RRDBNet(Networks consisting of Residual in Residual Dense Block)

  • make_layer(): 顺序生成指定数量的基本块
  • pixel_unshuffle():像素重组的逆过程 (t通道数增加,长宽缩小)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block. 基本块
        num_basic_block (int): number of blocks. 基本块的个数

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    [n, c, w, h]  ---> [n, c*scale*scale, w/scale, h/scale]
    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)  #
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)  # [b, c, h/scale, scale, w/scale, scale]
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
x = torch.rand(1, 64, 256, 256)
feat = pixel_unshuffle(x, scale=2)
print(feat.shape)  # [1, 256, 128, 128]

在这里插入图片描述

class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs. (输入通道数)
        num_out_ch (int): Channel number of outputs. (输出通道数)
        num_feat (int): Channel number of intermediate features. (中间特征的通道数)
            Default: 64
        num_block (int): Block number in the trunk network. Defaults: 23 (RDB块的个数)
        num_grow_ch (int): Channels for each growth. Default: 32. (增长的通道数)
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()

        self.scale = scale
        # 默认缩放因子是4,输入通道数为num_in_ch
        if scale == 2:
            # 如果缩放因子为2,输入通道数为2*2倍num_in_ch
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            # 如果缩放因子为1,则输入通道数为4*4倍num_in_ch
            num_in_ch = num_in_ch * 16

        # 浅层特征提取层:1个3×3,步长为1,填充为1的卷积层 [n, c, h, w] -->[n, 64, h, w]
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # 深层特征提取层:23个RRDB + 1个3×3,步长为1,填充为1的卷积层[n, c+64,h, w]-->[n, 64, h, w]
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # 上采样重建层:两个上采样操作(卷积+插值) + 2个3×3,步长为1,填充为1的卷积层
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            # 如果缩放因子为2,长宽变为原来的1/2倍,通道数增加为原来的2倍
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            # 如果缩放因子为1,长宽变为原来的1/4倍,通道数增加为原来的4倍
            feat = pixel_unshuffle(x, scale=4)
        else:
            # 默认情况下,缩放因子为4倍
            feat = x

        # 浅层特征提取层
        feat = self.conv_first(feat)

        # 深层特征提取层
        body_feat = self.conv_body(self.body(feat))

        # 残差连接:浅层输出+深层输出
        feat = feat + body_feat

        # 上采样重建层
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out

查看RRDBNet的网络结构

RRDBNet = RRDBNet(64, 64)
RRDBNet

在这里插入图片描述
在这里插入图片描述
测试RRDBNet

X =torch.rand(1, 64, 256, 256)
Y =RRDBNet(X)
print(Y.shape)

在这里插入图片描述

改进2:相对判别器

在这里插入图片描述

# gan loss (relativistic gan)  对抗损失
# 原始图像的判别得分
real_d_pred = self.net_d(self.gt).detach()
# 生成图像的判别得分
fake_g_pred = self.net_d(self.output)
# 真实图像的判别分数=原始图像的判别分数-生成图像的判别得分的平均值
l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
# 生成图像的判别分数=生成图像的判别分数-真实图像的判别得分的平均值
l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
# 生成器网络损失
l_g_gan = (l_g_real + l_g_fake) / 2

改进3:使用激活之前的VGG特征计算损失函数

from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


class VGGStyleDiscriminator(nn.Module):
    """VGG style discriminator with input size 128 x 128 or 256 x 256.

    It is used to train SRGAN, ESRGAN, and VideoGAN.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3. (输入数据的通道数,默认=3)
        num_feat (int): Channel number of base intermediate features.Default: 64. (中间特征的通道数,默认=64)
    """

    def __init__(self, num_in_ch, num_feat, input_size=128):
        super(VGGStyleDiscriminator, self).__init__()
        self.input_size = input_size
        assert self.input_size == 128 or self.input_size == 256, (f'input size must be 128 or 256, but received {input_size}')

        # convx_0:卷积核为3*3,图像尺寸大小不变:(n-k+2p)/s+1 = (n-3+2*1)/1+1 = n
        # convx_1: 卷积核为4*4,步长为2,图像尺寸大小减半:(n-k+2p)/s+1 = (n-4+2*1)/2+1= n/2
        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)

        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        if self.input_size == 256:
            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)

        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2  128/2=64

        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4  64/2=32

        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8  32/2=16

        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16  16/2=8

        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32  8/2=4

        if self.input_size == 256:
            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: / 64

        # spatial size: (4, 4)
        feat = feat.view(feat.size(0), -1)  # 将张量展成一行,输出为num_feat * 8 * 4 * 4的向量
        feat = self.lrelu(self.linear1(feat))  # 全连接+ReLU激活
        out = self.linear2(feat)  # 只进行全连接,不使用激活函数
        return out
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
以下是对tph-yolov5增加超分网络的代码示例: ```python import torch import torch.nn as nn import torch.nn.functional as F from torch.cuda.amp import autocast from models.common import Conv from models.yolo import Detect from models.super_resolution import SuperResolutionNet class TPH(nn.Module): def __init__(self, num_classes, input_channels=3, super_res_scale=4): super(TPH, self).__init__() self.num_classes = num_classes self.input_channels = input_channels self.super_res_scale = super_res_scale # Super Resolution Network self.super_res = SuperResolutionNet(scale=self.super_res_scale) # Backbone self.backbone = nn.Sequential( Conv(self.input_channels, 32, 3, 1), nn.MaxPool2d(2, 2), Conv(32, 64, 3, 1), nn.MaxPool2d(2, 2), Conv(64, 128, 3, 1), Conv(128, 64, 1, 1), Conv(64, 128, 3, 1), nn.MaxPool2d(2, 2), Conv(128, 256, 3, 1), Conv(256, 128, 1, 1), Conv(128, 256, 3, 1), nn.MaxPool2d(2, 2), Conv(256, 512, 3, 1), Conv(512, 256, 1, 1), Conv(256, 512, 3, 1), Conv(512, 256, 1, 1), Conv(256, 512, 3, 1), ) # Neck self.neck = nn.Sequential( Conv(512, 256, 1, 1), Conv(256, 512, 3, 1), Conv(512, 256, 1, 1), Conv(256, 512, 3, 1), Conv(512, 256, 1, 1), ) # Head self.head = nn.Sequential( Conv(256, 512, 3, 1), nn.Conv2d(512, (self.num_classes + 5) * 3, 1, 1, bias=True), Detect(num_classes=self.num_classes) ) @autocast() def forward(self, x): # Super Resolution Network x = self.super_res(x) # Backbone x = self.backbone(x) # Neck x = self.neck(x) # Head x = self.head(x) return x ``` 在这个示例中,我们在原始的TPH模型中添加了一个超分辨率网络。该网络将输入图像放大一定倍数,并将其用作TPH模型的输入。这可以提高模型对细节的感知能力,从而提高检测精度。 注意,这只是一个示例代码,并不是在所有情况下都适用的通用代码。根据您的具体需求,您可能需要修改或完全重写代码
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值