(已发布源码)图像修复——上下文编码器以及加入全局判别器的改进(Context Encoder and Global and Local Discriminator)

5 篇文章 0 订阅
1 篇文章 1 订阅

Context Encoder ,Global and Local Discriminator

最近做了一个图像修复的project,感觉这两篇论文很好,就写点笔记嘿嘿

Context Encoder(上下文编码器)

在AE基础上modify

Context Encoder是根据AE(Auto-Encoder)修改得到。

AE在我的之前的blog里有讲解过哦,可自取~(https://blog.csdn.net/HustQbw/article/details/114670216)

我们知道,AE是一种编码器-解码器架构,用于对数据压缩、解压缩和图像去噪,但由于未在原有数据上引入约束(即新的信息),所以难以用于生成。

但这种编码器-解码器架构十分值得借鉴,先将图像用卷积层编码到一个低分辨率多通道的feature map,再进行反卷积上采样。

但有一个问题:卷积层只能联系像素点与其领域的特征,难以联系到全局的大部分乃至所有location。

channel-wise fully-connected layer(创新点)

这里就有点像分类问题了,分类问题总是先用卷积层、池化层进行特征提取,最后flatten到一维向量,利用全连接层得到分类结果。

这里全连接层所起到的作用就是整合全局所有位置的特征信息,或者说,全连接层屏蔽了特征在location上的差异(很容易理解,一维的向量就不存在空间位置的说法了),进行了全局的整合

但全连接层同样存在一个众所周知的问题:参数太太太太多了

所以作者使用了另一种全连接层————channel-wise fully-connected layer

Thus, the number of parameters in this channel-wise fully-connected layer is mn4 , compared to m2n4 parameters in a fully-connected layer (ignoring thebias term).

说白了,其实就是用1x1的conv来代替全连接层,用(batch_size,m,1,1)来代替(batchsize,m),不知道这样说会不会很抽象,但应该可以理解:)

Context Encoder组成

得到channel-wise的fully-connected layer后就相当于得到了一个整合了图片除了缺失区域的其他部分的vector,再将这个vector上采样回到原图尺寸(原文里回到了64x64,因为作者修复的是中心部分的64x64的矩形部分)

所以总结可以得到:

Context Encoder = 下采样Encoder + channel-wise fully-connected layer +上采样Decoder

作者的两个网络结构如下:

一种适用于固定大小比例、固定形状的修复(128x128到64x64):
在这里插入图片描述

一种适用于不固定大小,不固定形状的修复(原图大小到原图大小):

判别器

判别器和GAN里的判别器类似,即判断真or假,来进行生成器与判别器的对抗

判别器的损失函数和常规GAN类似,就偷懒不赘述啦嘿嘿

Unsupervised-learning

Context Encoder运用的是一种无监督的学习,即你只需要提供你的task所需的data,然后随意得到mask,输入网络的既有完整图片也有得到的mask即可实现,而不需要认为的打标

Loss Function

Reconstruction Loss重构损失

重建Loss 是一个L2 距离,主要用来规范重建过程中的行为,让重建结果更具结构且与周围的信息一致。

作者尝试了L1和L2的loss,最后发现结果差不多。

我在跑代码的时候发现一个问题,就是L2 loss会让恢复的部分趋向于此位置整个训练集像素值的平均值,所以图片会比较模糊,难以捕捉到细节,和论文里所述一致

这里的原因是,要minimize L2 loss事实上就是在最大化一个高斯分布的log_likelihood

比如要拟合猫的分布,每一种猫都会带来一个峰值,L2 loss就是用一个高斯分布去拟合这个多峰的训练集分布,所以就会趋向于训练集的均值。

Adversarial Loss对抗损失

对于一副图像中的空白区域,可能有多种填充方式(修复任务的ground truth事实上是不唯一的,甚至可以认为没有ground truth)符合整体结构以及可以与周围信息很好地接洽,我们称之为不同的模式,而对抗Loss 就是用来进行模式选择。(学习缺失区域的数据分布

Joint Loss 联合损失

通过结合重建Loss 及对抗Loss,可以让生成结果即具备相当的结构性,又符合整幅图像的语义。

Globally and Locally Consistent Image Completion

这篇论文主要针对Context Encoder的两个点进行了修改

空洞卷积

不使用channel-wise fully-connected layer,而改为使用dilated convoluton,这样就不会限制输入的尺寸了,可以适用于任意分辨率图像的修复

在这里插入图片描述

并且,原文的dilation选的很大,我认为这样也可以一定程度上实现channel-wise fully-connected layer整合全局信息的作用,因为dilation选的很大就可以连接距离较远的两个特征点

还有一点,空洞卷积和标准卷积,在相同参数量的情况下,空洞卷积可以提升更多的感受野。原文提到,若采用相同参数的标准卷积只能得到99的感受野,而使用空洞卷积可以获得307的感受野(这里我怎么算都算的是287,大家可以算一下,纠正一下我T T)

加入了Global Discriminator

除了原有的Local Discriminator,这里还加入了全局的判别器,进行双重判别,这样可以在Context Encoder的基础上有更加好的视觉质量,并且对人的五官等复杂结构也可以有比较好的恢复

可以这样理解,比如要修复的地方是鼻子的右半边,但左半边结构还在,做个镜像就得到了右半边,这就是联合了全图判别的一个好处吧(根据论文标题,就是即提升了局部的一致性consistency,也提升了全局的一致性)

网络结构:
在这里插入图片描述

整个工程源码已经发布在github(🐕)

github和gitee都已经发啦

如果登不上github可以去码云gitee来一波git clone哦

记得来个star嘿嘿

https://gitee.com/HustQBW/context-encoder

https://github.com/HustQBW/Context-Encoder

  • 3
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
好的,我理解您的问题。为了回答您的问题,我需要先解释一下上下文编码和GAN修复的概念。 上下文编码通常是指将原始的输入数据转化为一种更高级的表示形式,这种形式可以帮助模型更好地理解输入数据的结构和语义信息。在图像处理中,上下文编码可以通过卷积神经网络来实现。GAN修复则是一种用于图像修复的技术,它使用生成判别两个网络结构,通过训练生成来修复图像中的缺陷区域,并通过判别来判断修复结果的真实性。 下面是一份用上下文编码加GAN修复128*128带有随机掩膜区域的图像的生成判别的代码: 生成代码: ```python import torch.nn as nn import torch.nn.functional as F class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # Encoder self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1) self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1) self.conv6 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1) self.conv7 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1) self.conv8 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1) # Decoder self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1) self.deconv2 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1) self.deconv3 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1) self.deconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1) self.deconv5 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1) self.deconv6 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1) self.deconv7 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1) self.deconv8 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1) def forward(self, x): # Encoder e1 = F.leaky_relu(self.conv1(x), negative_slope=0.2) e2 = F.leaky_relu(self.conv2(e1), negative_slope=0.2) e3 = F.leaky_relu(self.conv3(e2), negative_slope=0.2) e4 = F.leaky_relu(self.conv4(e3), negative_slope=0.2) e5 = F.leaky_relu(self.conv5(e4), negative_slope=0.2) e6 = F.leaky_relu(self.conv6(e5), negative_slope=0.2) e7 = F.leaky_relu(self.conv7(e6), negative_slope=0.2) e8 = F.leaky_relu(self.conv8(e7), negative_slope=0.2) # Decoder d1 = F.relu(self.deconv1(e8)) d1 = F.dropout(d1, 0.5, training=True) d2 = F.relu(self.deconv2(torch.cat([d1, e7], dim=1))) d2 = F.dropout(d2, 0.5, training=True) d3 = F.relu(self.deconv3(torch.cat([d2, e6], dim=1))) d3 = F.dropout(d3, 0.5, training=True) d4 = F.relu(self.deconv4(torch.cat([d3, e5], dim=1))) d4 = F.dropout(d4, 0.5, training=True) d5 = F.relu(self.deconv5(torch.cat([d4, e4], dim=1))) d6 = F.relu(self.deconv6(torch.cat([d5, e3], dim=1))) d7 = F.relu(self.deconv7(torch.cat([d6, e2], dim=1))) d8 = torch.tanh(self.deconv8(torch.cat([d7, e1], dim=1))) return d8 ``` 判别代码: ```python import torch.nn as nn import torch.nn.functional as F class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1) self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1) self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1) def forward(self, x): x = F.leaky_relu(self.conv1(x), negative_slope=0.2) x = F.leaky_relu(self.conv2(x), negative_slope=0.2) x = F.leaky_relu(self.conv3(x), negative_slope=0.2) x = F.leaky_relu(self.conv4(x), negative_slope=0.2) x = torch.sigmoid(self.conv5(x)) return x ``` 注意:这只是一个代码示例,并不是一个完整的可运行代码。您需要使用这个示例代码作为起点,并根据您的数据集和实际需求进行适当的修改。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值