starGAN 2023/3/3

starGAN-v1 &v2

模型主要结构

starGAN-v1

在这里插入图片描述

v1中的domain

在这里插入图片描述

以RaFD为例,v1中以one-hot编码来表达不同的domain信息,例如10010,他的含义为转换成生气和伤心的特征信息。

在这里插入图片描述

starGAN-v2

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6KAv3JiM-1677477710458)(C:\Users\mcpau\AppData\Roaming\Typora\typora-user-images\image-20230224214205088.png)]

1.生成器:与本质的GAN网络的生成器一样,G网络负责将输入的图像以想要的style风格进行合成图像的生成

2.Mapping network:这里的style就跟v1完全不一样,v1中的style与domain是一个意义也就是表达想要合成的特征信息,而v2中的style则是一串特征编码而不是onehot这种没有学习意义的编码,mapping network将latent code通过一个MLP分支出每个domain所对应的style code。本质就是在随机高斯噪声中学习出不同的stylecode来生成不同的图像。

3.Style encoder:输入一张图片,通过E网络可以学习到图片中的风格和所有的domain

4.判别器:尽可能的学习来判断输入的图片是真实的还是生成的,判别器需要不停的将判断结果告诉生成器,生成器就可以通过学习来欺骗判别器。

从图中可以看到styleEncoder和discriminator的特征提取网络是一个网络。在隐藏层个数和参数个数中也可以看出生成器所需要的计算资源是最多的

代码detail

  • ResBlock结构

在这里插入图片描述

class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)
    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x
  • AdainResBlock

在这里插入图片描述

class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.w_hpf = w_hpf
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
        return out

Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization https://arxiv.org/abs/1703.06868

  • 生成器

    在这里插入图片描述
    class Generator(nn.Module):
        def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
            super().__init__()
            dim_in = 2**14 // img_size
            self.img_size = img_size
            self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
            self.encode = nn.ModuleList()
            self.decode = nn.ModuleList()
            self.to_rgb = nn.Sequential(
                nn.InstanceNorm2d(dim_in, affine=True),
                nn.LeakyReLU(0.2),
                nn.Conv2d(dim_in, 3, 1, 1, 0))
    
            # down/up-sampling blocks
            repeat_num = int(np.log2(img_size)) - 4   ##8-4=4
            if w_hpf > 0:
                repeat_num += 1   #4+1=5
            for _ in range(repeat_num):
                dim_out = min(dim_in*2, max_conv_dim)  ##向下采样的维度,最高是512,每次通道翻倍
                self.encode.append(
                    ResBlk(dim_in, dim_out, normalize=True, downsample=True)) ##带IN归一化和avgpoooling的下采样残差快
                self.decode.insert(
                    0, AdainResBlk(dim_out, dim_in, style_dim,##  带上采样的adainResblock
                                   w_hpf=w_hpf, upsample=True))  # stack-like
                dim_in = dim_out
    
            # bottleneck blocks
            for _ in range(2):
                self.encode.append(
                    ResBlk(dim_out, dim_out, normalize=True))
                self.decode.insert(
                    0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))
    
            if w_hpf > 0:
                device = torch.device(
                    'cuda' if torch.cuda.is_available() else 'cpu')
                self.hpf = HighPass(w_hpf, device)
    
        def forward(self, x, s, masks=None):
            x = self.from_rgb(x)
            cache = {}
            for block in self.encode:
                if (masks is not None) and (x.size(2) in [32, 64, 128]):
                    cache[x.size(2)] = x   ##只对前三次特征提取的特征图进行热力分析
                x = block(x)
            for block in self.decode:
                x = block(x, s)
                if (masks is not None) and (x.size(2) in [32, 64, 128]):  ##masks就是原图的热力图
                    mask = masks[0] if x.size(2) in [32] else masks[1]
                    mask = F.interpolate(mask, size=x.size(2), mode='bilinear')  ##将热力图升维到原图的维度
                    x = x + self.hpf(mask * cache[x.size(2)])  ##意义就是将原图与热力关注点(可以理解为图片的核心注意力)一起放入highpass进行边缘检测
            return self.to_rgb(x)  ##将维度转换为3维R,G,B
        
     
    
  • mapping network

    class MappingNetwork(nn.Module):
        def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
            super().__init__()
            layers = []
            ####将16维的latent升维到512
            layers += [nn.Linear(latent_dim, 512)]
            layers += [nn.ReLU()]
            #####三次shared 512的自编码,特征压缩的过程。domain之间的信息是共享的
            for _ in range(3):
                layers += [nn.Linear(512, 512)]
                layers += [nn.ReLU()]
            self.shared = nn.Sequential(*layers)
            self.unshared = nn.ModuleList()
            ##########   K次(K是num_domains)   最后从512变成stylecode的64维,就是个还原的过程
            for _ in range(num_domains):
                self.unshared += [nn.Sequential(nn.Linear(512, 512),
                                                nn.ReLU(),
                                                nn.Linear(512, 512),
                                                nn.ReLU(),
                                                nn.Linear(512, 512),
                                                nn.ReLU(),
                                                nn.Linear(512, style_dim))]
        def forward(self, z, y):
            h = self.shared(z)
            out = []
            for layer in self.unshared:
                out += [layer(h)]
            out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)   ##所有domain的stylecode  B*K*64
            idx = torch.LongTensor(range(y.size(0))).to(y.device)  ##获取当前domain所在out的第一个序号
            s = out[idx, y]  # (batch, style_dim)      
            return s
    
  • style encoder

在这里插入图片描述

class StyleEncoder(nn.Module):
    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size  ##64
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]  ##3channel->64channel,3*3 kernal,stride = 1,padding=1

        repeat_num = int(np.log2(img_size)) - 2  ## 8-2=6次resblock
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)   ##每次通道数翻倍,但最大就是512channel
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]  ##下采样是stride=2的avgpooling
            dim_in = dim_out  ##再将输出维度变成输入维度,循环特征压缩

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]  ##512channel,4*4卷积,步长为1,padding为0
        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Linear(dim_out, style_dim)]  ##还原成stylecode的64维尺寸

  • 判别器

    class StyleEncoder(nn.Module):
        def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
            super().__init__()
            dim_in = 2**14 // img_size  ##64
            blocks = []
            blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]  ##3channel->64channel,3*3 kernal,stride = 1,padding=1
    
            repeat_num = int(np.log2(img_size)) - 2  ## 8-2=6次resblock
            for _ in range(repeat_num):
                dim_out = min(dim_in*2, max_conv_dim)   ##每次通道数翻倍,但最大就是512channel
                blocks += [ResBlk(dim_in, dim_out, downsample=True)]  ##下采样是stride=2的avgpooling
                dim_in = dim_out  ##再将输出维度变成输入维度,循环特征压缩
    
            blocks += [nn.LeakyReLU(0.2)]
            blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]  ##512channel,4*4卷积,步长为1,padding为0
            blocks += [nn.LeakyReLU(0.2)]
            self.shared = nn.Sequential(*blocks)
    
            self.unshared = nn.ModuleList()
            for _ in range(num_domains):   ##(num_domains,style_dim)   
                self.unshared += [nn.Linear(dim_out, style_dim)]  ##还原成stylecode的64维尺寸
    class Discriminator(nn.Module):
        def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
            super().__init__()
            dim_in = 2**14 // img_size
            blocks = []
            blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
    
            repeat_num = int(np.log2(img_size)) - 2
            for _ in range(repeat_num):
                dim_out = min(dim_in*2, max_conv_dim)
                blocks += [ResBlk(dim_in, dim_out, downsample=True)]
                dim_in = dim_out
    
            blocks += [nn.LeakyReLU(0.2)]
            blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
            blocks += [nn.LeakyReLU(0.2)]
            blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]  ##最后通过一个stride=1的1*1卷积,因为判别器不需要最后生成64维的stylecode,它只需要生成0或者1(真或者假)也就是一维数据,有很多个domain,所以最后输出维度就是num_domains*1
            self.main = nn.Sequential(*blocks)
    

starGAN-v2的loss函数

  • 生成对抗的loss

L a d v = E x , y [ l o g D y ( x ) ] + E x , y ~ , z [ l o g ( 1 − D y ~ ( G ( x , s ~ ) ) ) ] (1) \mathcal{L_adv}= \mathbb E_{x,y}[logD_y(x)]+\mathbb E_{x,\tilde{y},z}[log(1-D_{\tilde{y}}(G(x,\tilde{s})))] \tag{1} Ladv=Ex,y[logDy(x)]+Ex,y~,z[log(1Dy~(G(x,s~)))](1)

  • 风格重构(衡量的是想要的风格和生成出的风格的相似程度)
    L sty  = E x , y ~ , z [ ∥ s ~ − E y ~ ( G ( x , s ~ ) ) ∥ 1 ] . (2) \mathcal{L}_{\text {sty }}=\mathbb{E}_{\mathbf{x}, \widetilde{y}, \mathbf{z}}\left[\left\|\widetilde{\mathbf{s}}-E_{\widetilde{y}}(G(\mathbf{x}, \widetilde{\mathbf{s}}))\right\|_{1}\right] . \tag{2} Lsty =Ex,y ,z[s Ey (G(x,s ))1].(2)

  • 风格多样化(衡量的是两个风格生成的图片的风格的相似性程度)
    L d s = E x , y ~ , z 1 , z 2 [ ∥ G ( x , s ~ 1 ) − G ( x , s ~ 2 ) ∥ 1 ] (3) \mathcal{L}_{d s}=\mathbb{E}_{\mathbf{x}, \widetilde{y}, \mathbf{z}_{1}, \mathbf{z}_{2}}\left[\left\|G\left(\mathbf{x}, \widetilde{\mathbf{s}}_{1}\right)-G\left(\mathbf{x}, \widetilde{\mathbf{s}}_{2}\right)\right\|_{1}\right] \tag{3} Lds=Ex,y ,z1,z2[G(x,s 1)G(x,s 2)1](3)

  • cycle consistency(衡量的是是否生成的图片不是随机乱生成的,即生成的图片还原成原图和原图之间的相似程度)
    L cyc  = E x , y , y ~ , z [ ∥ x ~ − G ( G ( x , s ~ ) , s ^ ) ∥ 1 ] (4) \mathcal{L}_{\text {cyc }}=\mathbb{E}_{\mathbf{x},\mathbf{y}, \widetilde{y}, \mathbf{z}}\left[\left\|\widetilde{\mathbf{x}}-G(G(\mathbf{x}, \widetilde{\mathbf{s}}),\hat{\mathbf{s}})\right\|_{1}\right] \tag{4} Lcyc =Ex,y,y ,z[x G(G(x,s ),s^)1](4)

  • 总loss
    m i n G , F , E    m a x D L cyc  + λ sty L sty  − λ ds L ds  + λ cyc L cyc  (5) \underset{G,F,E}{min} \; \underset{D}{max} \quad \mathcal{L}_{\text {cyc }}+\lambda_{\text{sty}} \mathcal{L}_{\text {sty }}-\lambda_{\text{ds}} \mathcal{L}_{\text {ds }}+\lambda_{\text{cyc}} \mathcal{L}_{\text {cyc }} \tag{5} G,F,EminDmaxLcyc +λstyLsty λdsLds +λcycLcyc (5)

训练实例

在这里插入图片描述

预测实例

在这里插入图片描述

video_ref

总结与启示

1.模型

starGAN
v2相较于v1来说网络主要结构都进行了很大的改进,v1需要固定的两个domain之间需要2个生成器网络和2个判别器网络,如果需要生成的domain数为n个就需要4*n个网络需要训练,显然是不合理的,最大的问题是描述需要生成的类的信息时使用了one-hot编码这样的编码没有学习的意义并且指定了只能生成这几个类不能让模型自己学习图片中的风格。v2中引入了mapping
network和style encoder来自动学习图片中的风格信息并且生成了有学习意义的特征编码style
coding。v2中的下采样的特征压缩主要是通过平均池化的方法,如果全部使用卷积层是否会得到更好的效果?另外HPF的边缘提取和mask增强可否用self-attention来聚焦关键部位的特征提取?

2.图像风格迁移的前景

openai提出的dalle2(AI自动绘画工具)再一次定义了行业的top水平,基于diffusion扩散模型的dalle2做到了更好的效果,diffusion只有一个网络用来拟合出一个未知的推导不出的参数,其余都是通过数理统计的方式来生成需要的图片,而不是如GAN需要多个网络进行训练图片特征。
3.GAN的应用

GAN可以做很多的风格迁移的应用,我学习GAN的初衷就是做目标检测的实际项目时需要大量的多样化数据,限于一些数据的获取有限制,所以需要GAN来自动生成一些多样化的数据来增强模型的鲁棒性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值