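ESRGAN Network Architecture
I. Generator
The generator takes a low-resolution image as input and produces a high-resolution image.
The network consists of three parts: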
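1. Shallow feature extraction. The low-resolution image first passes through a single 3×3 convolution (conv_first) that maps the input channels to 64 (num_feat). In the code below, this convolution is not followed by an activation.
2. RRDB trunk. The trunk stacks N RRDBs (Residual in Residual Dense Blocks; N = 23 in ESRGAN). Each RRDB contains 3 RDBs (Residual Dense Blocks) plus a residual connection. Each RDB contains 5 convolutions, and the first four are followed by LeakyReLU. A final convolution (conv_body) and a long skip connection close the trunk.
3. Upsampling. Two ×2 upsampling steps follow. Each step is a nearest-neighbor interpolation followed by a convolution. Together they make the height and width 4× those of the input, which produces the resolution increase.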
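The three stages can be traced as pure shape bookkeeping. The sketch below is hypothetical: the function name trace_generator_x4 does not come from any library. It assumes the default ×4 setting and that every 3×3 convolution uses stride 1 and padding 1, as in the RRDBNet code in this post. It only tracks (channels, height, width) and involves no tensors.

```python
def trace_generator_x4(in_ch, h, w, num_feat=64, out_ch=3):
    """Return (stage, channels, height, width) tuples for the x4 generator."""
    stages = [("input", in_ch, h, w)]
    # 1. shallow feature extraction: 3x3 conv, stride 1, padding 1 keeps H and W
    stages.append(("conv_first", num_feat, h, w))
    # 2. RRDB trunk + conv_body: all 3x3/stride-1/padding-1 convs, so the shape is unchanged
    stages.append(("body + conv_body", num_feat, h, w))
    # 3. two rounds of nearest-neighbour x2 interpolation, each followed by a 3x3 conv
    for name in ("conv_up1", "conv_up2"):
        h, w = h * 2, w * 2
        stages.append((name, num_feat, h, w))
    # conv_hr keeps the shape; conv_last maps num_feat channels to out_ch
    stages.append(("conv_hr", num_feat, h, w))
    stages.append(("conv_last", out_ch, h, w))
    return stages


for stage in trace_generator_x4(3, 32, 32):
    print(stage)
```

For a 3×32×32 input, the last stage is 3×128×128.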
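The structure of the dense residual blocks is as follows:
The RRDB (Residual in Residual Dense Block) uses two levels of residual learning. At the outer level, an RRDB is one large residual structure. Its trunk consists of 3 RDBs (Residual Dense Blocks), and the trunk output, scaled by 0.2, is added to the identity path.
In the code, each RDB has 5 convolutions. torch.cat concatenates feature maps along the channel dimension, so the input channel count of successive convolutions grows from num_feat up to num_feat + 4 * num_grow_ch. The last convolution maps the channels back to num_feat, and x5 * 0.2 + x forms the residual connection.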
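The channel growth can be checked with a few lines of arithmetic. The helper below is hypothetical (rdb_conv_channels is only an illustrative name). It restates the in/out channel counts of conv1 to conv5 from the ResidualDenseBlock code.

```python
def rdb_conv_channels(num_feat=64, num_grow_ch=32):
    """Return (in_channels, out_channels) for conv1..conv5 of one RDB."""
    specs = []
    for i in range(4):
        # conv(i+1) sees x plus the i earlier outputs x1..xi, concatenated on dim 1
        specs.append((num_feat + i * num_grow_ch, num_grow_ch))
    # conv5 sees x and x1..x4, and maps back to num_feat so that x5 * 0.2 + x is well defined
    specs.append((num_feat + 4 * num_grow_ch, num_feat))
    return specs


print(rdb_conv_channels())  # [(64, 32), (96, 32), (128, 32), (160, 32), (192, 64)]
```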
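The Residual Dense Block (RDB) combines a residual block (ResBlock) with a dense block. It works in three steps:
- Densely connected convolutional layers extract rich local features.
- The output of the preceding RDB is connected directly to every layer of the current RDB.
- Local feature fusion then adaptively learns more effective features from both the previous and the current local features, which makes training more stable.

Unlike a conventional dense block, this network removes the BN (batch normalization) layers. During training, BN normalizes features with the mean and variance of the current batch. During testing, it uses a mean and variance estimated over the whole training set. When the statistics of the training and test sets differ greatly, BN layers tend to introduce unpleasant artifacts and limit generalization.

The ESRGAN authors observe empirically that BN layers are likely to cause artifacts when the network is deep and trained in a GAN framework. These artifacts appear occasionally across iterations and settings, which conflicts with the need for stable performance throughout training. BN layers are therefore removed for training stability and consistency.

BN layers also consume as much memory as the convolutional layers before them, so removing them substantially reduces GPU memory usage. In short, removing BN improves generalization and reduces computational complexity and memory usage.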
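Below is a toy, pure-Python illustration of the train/test mismatch described above. It leaves out BN's learnable scale and shift, and the running statistics are made-up numbers. It shows only one thing: a batch is normalized differently in the two modes when its statistics differ from the training estimates. It does not show that this produces the artifacts the authors report.

```python
import statistics


def normalize(values, mean, var, eps=1e-5):
    # BN's normalisation step without the learnable scale/shift: (v - mean) / sqrt(var + eps)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


# Hypothetical running estimates accumulated on the training data
running_mean, running_var = 0.0, 1.0
# A test batch whose statistics differ from the training data
test_batch = [2.0, 3.0, 4.0]

# Training-mode behaviour: use the batch's own mean and (population) variance
train_mode = normalize(test_batch, statistics.mean(test_batch), statistics.pvariance(test_batch))
# Eval-mode behaviour: use the running estimates instead
eval_mode = normalize(test_batch, running_mean, running_var)

print(train_mode)  # centred around 0
print(eval_mode)   # still centred around 3
```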
import torch
from torch import nn
from basicsr.archs.arch_util import default_init_weights  # BasicSR helper


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (weights scaled by 0.1)
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        # torch.cat on dim 1 stacks channels, so each step adds num_grow_ch input channels
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))  # no activation; back to num_feat channels
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x  # residual connection; 0.2 is the residual scaling factor
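The trunk is built from multiple basic blocks, and each basic block is an RRDB. An RRDB cascades 3 RDBs and adds a residual connection around them.
The code is as follows: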
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x  # outer residual connection
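Every convolution in an RDB has a known shape, so the parameter count follows directly. The helpers below are hypothetical. They assume nn.Conv2d's default bias=True, which the code above relies on. The 0.2 residual scaling and the additions add no parameters.

```python
def conv_params(in_ch, out_ch, k=3, bias=True):
    # k x k weights for every (input, output) channel pair, plus one bias per output channel
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def rdb_params(num_feat=64, num_grow_ch=32):
    # conv1..conv4 each output num_grow_ch channels from a growing input
    growth = sum(conv_params(num_feat + i * num_grow_ch, num_grow_ch) for i in range(4))
    # conv5 fuses everything back to num_feat channels
    return growth + conv_params(num_feat + 4 * num_grow_ch, num_feat)


def rrdb_params(num_feat=64, num_grow_ch=32):
    # three RDBs in sequence; the scaled residual additions are parameter-free
    return 3 * rdb_params(num_feat, num_grow_ch)


print(rdb_params())        # 239808
print(rrdb_params())       # 719424
print(23 * rrdb_params())  # 16546752 for the 23-block trunk
```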
The overall structure of the generator is:
import torch.nn.functional as F
from torch import nn
from basicsr.archs.arch_util import make_layer, pixel_unshuffle  # BasicSR helpers


class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall upsampling factor (1, 2 or 4). Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale  # overall upsampling factor
        if scale == 2:
            # pixel_unshuffle(scale=2) halves H and W and multiplies the input channels by 4;
            # the fixed x4 upsampling below then gives a net x2
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            # pixel_unshuffle(scale=4) quarters H and W and multiplies the input channels by 16;
            # the x4 upsampling restores the original size (net x1)
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        # make_layer stacks num_block identical basic blocks (here RRDB); ESRGAN uses 23
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)  # C x H x W -> 4C x H/2 x W/2
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)  # C x H x W -> 16C x H/4 x W/4
        else:
            feat = x  # scale == 4: the input goes straight to conv_first
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat  # long skip connection around the whole trunk
        # upsample: two nearest-neighbour x2 interpolations, each followed by conv + LeakyReLU
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
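For scale=2 and scale=1, pixel_unshuffle trades spatial size for channels before conv_first. The pure-Python version below works on a single [C][H][W] nested list; pixel_unshuffle_nested is a hypothetical name. It is meant to follow the same channel ordering as torch.nn.functional.pixel_unshuffle: output channel c * s² + i * s + j holds the pixels at row offset i and column offset j of each s×s patch.

```python
def pixel_unshuffle_nested(img, scale):
    """img: nested list [C][H][W]. Returns [C * scale**2][H // scale][W // scale]."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    if h % scale or w % scale:
        raise ValueError("H and W must be divisible by scale")
    out = []
    for ch in range(c):
        for i in range(scale):        # row offset inside each scale x scale patch
            for j in range(scale):    # column offset inside each patch
                out.append([[img[ch][y * scale + i][x * scale + j]
                             for x in range(w // scale)]
                            for y in range(h // scale)])
    return out


# One 4x4 channel holding 0..15 row by row
img = [[[4 * r + col for col in range(4)] for r in range(4)]]
for channel in pixel_unshuffle_nested(img, 2):
    print(channel)
# [[0, 2], [8, 10]]
# [[1, 3], [9, 11]]
# [[4, 6], [12, 14]]
# [[5, 7], [13, 15]]
```

With an RGB input and scale=2, this gives 12 channels at half the height and width. That is why __init__ multiplies num_in_ch by 4 (and by 16 for scale=1).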
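II. Discriminator
ESRGAN uses a relativistic discriminator. A standard discriminator estimates the probability that an input image is real. The relativistic discriminator instead tries to predict the probability that a real image x_r is relatively more realistic than a fake image x_f.
Specifically, the standard discriminator is replaced by the Relativistic average Discriminator (RaD):
$$D_{Ra}(x_r, x_f) = \sigma\big(C(x_r) - \mathbb{E}_{x_f}[C(x_f)]\big)$$
In this formula:
- σ is the sigmoid function.
- C(x) is the raw (pre-sigmoid) discriminator output.
- E_{x_f}[·] denotes averaging over all fake images in a mini-batch, and E_{x_r}[·] likewise averages over the real images.

The discriminator adversarial loss is therefore defined as:
$$L_D^{Ra} = -\mathbb{E}_{x_r}\big[\log D_{Ra}(x_r, x_f)\big] - \mathbb{E}_{x_f}\big[\log\big(1 - D_{Ra}(x_f, x_r)\big)\big]$$
The generator adversarial loss is defined as:
$$L_G^{Ra} = -\mathbb{E}_{x_r}\big[\log\big(1 - D_{Ra}(x_r, x_f)\big)\big] - \mathbb{E}_{x_f}\big[\log D_{Ra}(x_f, x_r)\big]$$
L_G^{Ra} contains both x_r and x_f, so the generator receives gradients from both real and generated data.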
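The two losses can be evaluated directly from the formulas for a toy mini-batch of raw discriminator scores C(x). This pure-Python sketch follows the equations above literally and is not taken from the ESRGAN training code. The score values in the usage lines are arbitrary.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def mean(xs):
    return sum(xs) / len(xs)


def d_ra(c_a, c_b_batch):
    # D_Ra(a, b) = sigmoid(C(a) - E_b[C(b)]), the expectation taken over the other batch
    return sigmoid(c_a - mean(c_b_batch))


def rad_losses(c_real, c_fake):
    """c_real, c_fake: raw scores C(x) for the real and fake images of one mini-batch."""
    # L_D^Ra: pushes real scores above the average fake score and fake scores below the average real score
    loss_d = (-mean([math.log(d_ra(cr, c_fake)) for cr in c_real])
              - mean([math.log(1 - d_ra(cf, c_real)) for cf in c_fake]))
    # L_G^Ra: the same terms with the targets swapped, so both real and fake scores matter
    loss_g = (-mean([math.log(1 - d_ra(cr, c_fake)) for cr in c_real])
              - mean([math.log(d_ra(cf, c_real)) for cf in c_fake]))
    return loss_d, loss_g


# Indistinguishable scores: D_Ra = 0.5 everywhere, so both losses equal 2 * ln 2 (about 1.386)
print(rad_losses([0.0, 0.0], [0.0, 0.0]))
# A discriminator that clearly separates real from fake: small loss_d, large loss_g
print(rad_losses([3.0, 2.0], [-1.0, -2.0]))
```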
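The code computes both the generator and the discriminator adversarial losses with BCE, which measures the binary cross-entropy between the target and the output. It uses PyTorch's logits version, which applies the sigmoid σ internally:
torch.nn.functional.binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)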
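A logits-based BCE is enough here. Take the logit z = C(x_r) - E[C(x_f)]. With target 1, BCE on z equals -log D_Ra(x_r, x_f); with target 0, it equals -log(1 - D_Ra). The scalar sketch below uses the standard numerically stable form max(z, 0) - z·y + log(1 + e^(-|z|)). It is a stand-in for illustration, not PyTorch's implementation, and relativistic_logit is a hypothetical helper.

```python
import math


def bce_with_logits(z, y):
    # Stable BCE on a logit z with target y: max(z, 0) - z * y + log(1 + exp(-|z|))
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def relativistic_logit(c_a, c_b_batch):
    # C(a) - E_b[C(b)]: the argument of the sigmoid in D_Ra
    return c_a - sum(c_b_batch) / len(c_b_batch)


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


z = relativistic_logit(1.5, [0.2, -0.4])  # 1.5 - (-0.1) = 1.6
print(bce_with_logits(z, 1.0), -math.log(sigmoid(z)))      # same value: -log D_Ra
print(bce_with_logits(z, 0.0), -math.log(1 - sigmoid(z)))  # same value: -log(1 - D_Ra)
```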
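C(x) is the discriminator network. It takes an image and outputs a single number, and its input must be 128×128. Ten convolutions progressively downsample the image to 4×4 while the channel count grows to num_feat*8. view then flattens the result into a one-dimensional vector, and two fully connected layers produce the single output value.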
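The downsampling can be checked with the standard convolution output-size formula. The helpers below are hypothetical. They assume the kernel/stride/padding values of the VGGStyleDiscriminator128 code (3×3/1/1 for convX_0, 4×4/2/1 for convX_1) and dilation 1.

```python
def conv_out(size, k, s, p):
    # Conv2d output size with dilation 1: floor((size + 2p - k) / s) + 1
    return (size + 2 * p - k) // s + 1


def trace_discriminator(size=128, num_feat=64):
    """Return (spatial sizes after each stage, final channels, flattened length)."""
    sizes = []
    for stage in range(5):
        # channels: num_feat, 2x, 4x, 8x, then stay at 8x in stage 4
        channels = num_feat * min(2 ** stage, 8)
        size = conv_out(size, k=3, s=1, p=1)  # convX_0: size unchanged
        size = conv_out(size, k=4, s=2, p=1)  # convX_1: size halved
        sizes.append(size)
    return sizes, channels, channels * size * size


print(trace_discriminator())       # ([64, 32, 16, 8, 4], 512, 8192)
print(trace_discriminator(96)[2])  # 4608, which would not match linear1's 8192 inputs
```

The 96×96 case shows why forward asserts a 128×128 input: any other size changes the flattened length expected by linear1.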
from torch import nn


class VGGStyleDiscriminator128(nn.Module):
    """VGG style discriminator with input size 128 x 128.

    It is used to train SRGAN and ESRGAN.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
    """

    def __init__(self, num_in_ch=3, num_feat=64):
        super(VGGStyleDiscriminator128, self).__init__()
        # convX_0: 3x3 conv, stride 1 (keeps H and W; doubles the channels in stages 1-3)
        # convX_1: 4x4 conv, stride 2 (halves H and W)
        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
        # fully connected layers: in_features / out_features are the per-sample input / output vector sizes
        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        assert x.size(2) == 128 and x.size(3) == 128, (f'Input spatial size must be 128x128, but received {x.size()}.')

        feat = self.lrelu(self.conv0_0(x))
        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: (64, 64), channels: num_feat
        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: (32, 32), channels: num_feat*2
        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: (16, 16), channels: num_feat*4
        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: (8, 8), channels: num_feat*8
        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: (4, 4), channels: num_feat*8

        feat = feat.view(feat.size(0), -1)  # flatten to a vector of length num_feat * 8 * 4 * 4 per sample
        feat = self.lrelu(self.linear1(feat))  # fully connected layer followed by activation
        out = self.linear2(feat)  # final fully connected layer, no activation (raw score C(x))
        return out
III. Loss Functions
The loss has three parts, configured in the training YAML file:
pixel_opt:  # pixel (content) loss
  type: L1Loss
  loss_weight: !!float 1e-2
  reduction: mean
perceptual_opt:  # perceptual loss
  type: PerceptualLoss
  layer_weights:
    'conv5_4': 1  # before relu
  vgg_type: vgg19
  use_input_norm: true
  range_norm: false
  perceptual_weight: 1.0
  style_weight: 0
  criterion: l1
gan_opt:  # adversarial loss
  type: GANLoss
  gan_type: vanilla
  real_label_val: 1.0
  fake_label_val: 0.0
  loss_weight: !!float 5e-3
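The sketch below shows how the three weights combine into the generator objective. It assumes each term is computed unweighted and then scaled by its loss_weight (or perceptual_weight) before summing. This matches the form L_G = L_percep + λ·L_G^Ra + η·L_1 in the ESRGAN paper with λ = 5e-3 and η = 1e-2. The function name is hypothetical and the inputs in the usage line are arbitrary.

```python
# Weights taken from the YAML above (style_weight is 0, so no style term is included)
LOSS_WEIGHTS = {"pixel": 1e-2, "perceptual": 1.0, "gan": 5e-3}


def total_generator_loss(l_pix, l_percep, l_gan, weights=LOSS_WEIGHTS):
    """Weighted sum of unweighted pixel (L1), perceptual and adversarial terms."""
    return (weights["pixel"] * l_pix
            + weights["perceptual"] * l_percep
            + weights["gan"] * l_gan)


print(total_generator_loss(1.0, 1.0, 1.0))
```

With equal raw values, the perceptual term dominates and the adversarial term contributes least, reflecting the weights chosen in the config.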