GaitGAN学习笔记

论文

1 Abstract

image-20220228142803585

​ 本文利用GAN模型去生成不同视角下的步态剪影图,并且解决背包等协变量的问题。

​ 创新之处在与,相比于传统的GAN模型,其使用了两个不同的判别器:一个是传统的真伪辨别器,另一个用于维持步态剪影中含有的身份信息。

2 Proposed Method

最大的挑战:生成的步态剪影图不仅要看上去是真的,还需要保持人体可辨别的身份信息。

2.1 GEI输入

​ 插播关于GEI的制作步骤:

​ ① 将二值轮廓图像缩放为统一的标准尺寸,即归一化

image-20220228144659526

​ ② 将一个步态序列的归一化图片加权求和生成一张图片(0-255范围)

image-20220228144750889

2.2 GAN

​ GAN中有两个角色:**生成器G和判别器D,他们相互对抗共同训练。**其中,G类似生成假钞的罪犯,D类似分辨假钞的警察,G和D相互对抗,G生成的假币更加逼真,D的判别能力更加强大。训练的最终目标是使得D无法再判别出G生成数据的真伪(即达到纳什均衡)。

​ 生成器G:接受一个噪声向量,生成数据

​ 判别器D:接受真实数据,或者生成器G生成的数据,对他们进行二分类(分辨出真假)

​ 训练过程:

image-20220228145927080

​ 其中需要注意的是,判别器D的生成是0-1之间的数据。例如D判别一个数据是100%真实数据,它将输出1。

​ · 训练判别器:在最初的iteration中,对D进行反复训练,使得D(真实数据) >> 1, D(G(噪声数据)) >> 0从而logD(真实数据) >> 0,log(1-D(G(噪声数据)))>> 0,让loss最大(loss的最大值逼近于0)。

​ · 训练生成器:目的是让判别器出错,让 D(G(噪声数据)) >> 1,从而log(1-D(G(噪声数据)))>> 负无穷小,让loss最小(无穷小)。

2.3 引入PixelGAN原理

​ GAN的输入也可以是图片而不是噪声向量。例如在PixelDTGAN中,它可以识别输入图片与目标图片之间的像素级转换,同时可以建立起输入域与目标域之间的语义含义,从而保证生成的图片看上去真实,同时维持其语义含义。

image-20220228155429070

image-20220228155452868

​ 对于Real/fake D,其作用就是用于尽最大可能地让生成的数据看上去是真实的,它的输入是生成的数据以及真实的数据,产生一个概率值用于判断图像是由生成器生成的还是真实数据。例如:对于真实图片,label为1;对于生成器生成的图片,label为0,做二分类。

​ 其损失函数为(最小化):

image-20220228160940226

​ 对于Domain D,起作用就是用于保持语义信息,它的输入是一对源图像的目标图像(一个相关的和一个不相关的),产生一个概率值用于判断这对图像是否关联。例如:对于一对源图片与目标图片,若相关联label为1,若不关联label为0,做二分类。

​ 其损失函数为(最小化):

image-20220228161816325

​ 如果目标图像是与源图像相关的图像,让D(I_s, I) >> 1;如果目标图像是生成的图像,或者是与源图像无关的图像,让D(I_s, I) >> 0

2.4 GaitGAN

与PixelDTGAN的想法类似,GaitGAN将所有视角以及带有协变量的GEI当做源图像,把正常的90°的GEI当做目标图像。

image-20220228162333853

​ 对于编码解码器,利用常规的CNN:

image-20220228162642805

​ 对于Real/fake D,其用于预测生成的图片是否为真实的图片。如果生成90°的NM图片,输出为1,否则为0。

image-20220228162617612

​ 对于Domain D(文中命名为identification discriminator),用于预测图片是否相关,如果目标图与源图片是同一个人,输出1;如果不是同一个人或者是生成出来的图片,则输出0。

image-20220228163148860

代码

1 主干网络

image-20220301104340904

​ 在解码器中运用到了反卷积,此处做补充:

image-20220301105158481

class NetG(nn.Module):
    def __init__(self, nc=3, ngf=96):
        super(NetG, self).__init__()
        self.converter = nn.Sequential(
            nn.Conv2d(nc, ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf, ngf*2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf*2, ngf*4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ngf*4, ngf*8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ngf*8),
            nn.LeakyReLU(0.2, True),
			
            # 反卷积
            nn.ConvTranspose2d(ngf*8, ngf*4,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
			
            # 反卷积
            nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),

            # 反卷积
            nn.ConvTranspose2d(ngf*2, ngf, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # 反卷积
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1,
                               bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.converter(x)
        return x


class NetD(nn.Module):
    def __init__(self, nc=3, ndf=96):
        super(NetD, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.discriminator(x)
        return x.view(-1, 1)


'''
domain discriminator
'''
class NetA(nn.Module):
    def __init__(self, nc=3, ndf=96):
        super(NetA, self).__init__()
        self.discriminator = nn.Sequential(
            nn.Conv2d(nc*2, ndf, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*2),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*4),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1,
                      bias=False),
            nn.BatchNorm2d(ndf*8),
            nn.LeakyReLU(0.2, True),

            nn.Conv2d(ndf*8, 1, kernel_size=4, stride=4, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.discriminator(x)
        return x.view(-1, 1)

2 迭代过程

'''更新生成器'''
lossG = 0
optimG.zero_grad()
fake = netg(img)
output = netd(fake)  # 生成的图

label.fill_(real_label)  # 骗过D,是真标签
lossGD = F.binary_cross_entropy(output, label)
lossG += lossGD.item()
lossGD.backward(retain_graph=True)

faked = th.cat((img, fake), 1)
output = neta(faked)  # 生成的图
label.fill_(real_label)  # 骗过A,是真标签
lossGA = F.binary_cross_entropy(output, label)
lossG += lossGA.item()
lossGA.backward()
optimG.step()  # 一起更新D和A,骗过他们
'''更新fake/real 辨别器'''
label.fill_(real_label)  # 真标签
output = netd(ass_label)  # 相关联的图,是真图
lossD_real1 = F.binary_cross_entropy(output, label)
lossD += lossD_real1.item()
lossD_real1.backward()

label.fill_(real_label)  # 真标签
output1 = netd(noass_label)  # 不相关的图,但也是真的图
lossD_real2 = F.binary_cross_entropy(output1, label)
lossD == lossD_real2.item()
lossD_real2.backward()

label.fill_(fake_label)  # 假标签
fake = netg(img).detach()  # 生成的图
output2 = netd(fake)

lossD_fake = F.binary_cross_entropy(output2, label)
lossD += lossD_fake.item()
lossD_fake.backward()

optimD.step()  # 更新D
'''更新Domain 辨别器'''
label.fill_(real_label)  # 真标签
output1 = neta(assd)  # 相关的图
lossA_real1 = F.binary_cross_entropy(output1, label)
lossA += lossA_real1.item()
lossA_real1.backward()

label.fill_(fake_label)  # 假标签
output = neta(noassd)  # 不相关的图
lossA_real2 = F.binary_cross_entropy(output, label)
lossA += lossA_real2.item()
lossA_real2.backward()

label.fill_(fake_label)  # 假标签
output = neta(faked)  # 生成的图
lossA_fake = F.binary_cross_entropy(output, label)
lossA += lossA_fake.item()
lossA_fake.backward()
optimA.step()  # 更新A
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值