[CVPR2020] StarGAN v2

本文详细介绍了StargAN v2的改进,包括将原本的attribute概念替换为domain和style,并提出了无需手动标注的风格转换方法。通过多任务鉴别器、R1正则和映射网络提升稳定性与多样性。结构上,Generator、Discriminator、Style Encoder和Mapping Network的改进增强了模型性能。训练过程和损失函数的调整确保了生成图像的质量和多样性。此外,FID和LPIPS作为评价指标,展示了模型的出色表现。
摘要由CSDN通过智能技术生成

网络改进

StarGAN v1 中对attribute、domain的定义

We denote the terms attribute as a meaningful feature inherent in an image such as hair color, gender or age, and attribute value as a particular value of an attribute, e.g., black/blond/brown for hair color or male/female for gender. We further denote domain as a set of images sharing the same attribute value. For example, images of women can represent one domain while those of men represent another

StarGAN v2 中对domain、style的定义

domain implies a set of images that can be grouped as a visually distinctive category, and each image has a unique appearance, which we call style. For example, we can set image domains based on the gender of a person, in which case the style in- cludes makeup, beard, and hairstyle

Stargan v1 结构如下:
在这里插入图片描述
StarGAN v1 将有同样一个 attribute value 的一组图片作为一个的 domain。以CelebA为例,其attribute包括hair color(attribute values 有 black/blond/brown)、gender(attribute values 有 male/female)等。

问题在于,1、StarGAN 风格转换的图像部分很局限,多样性差;2、这里的attribute需要人工标出,没标出就无法学习,当存在多种style或domain时,很棘手。比如有一组全新domain的图片,你需要将你的图片转换成他的风格,那你需要单独标出.

StarGAN 改进版本,不需要具体标出style标签(attribute),只需1、输入源domain的图像,以及目标domain的一张指定参考图像(Style Encoder网络学习其style code),就可将源图像转换成 目标domain+参考图像style 的迁移图像;或者2、输入源domain的图像,以及随机噪声(mapping网络将其映射为指定domain的随机style code),就可将源图像转换成 目标domain+随机style 的迁移图像
Stargan v2 结构如下:

在这里插入图片描述
改进过程如下表:
在这里插入图片描述

基于(A)StarGAN,改进尝试如下,每点改进效果见下图:

  • (B)将原ACGAN+PatchGAN的鉴别器 换成 多任务鉴别器,使生成器能转换全局结构。
  • (C)引入R1正则与AdIN增加稳定度
  • (D)直接引入潜变量z增加多样性(无法有效,只能改变某一固定区域,而不是全局)
  • (E)将(D)的改进换成 引入映射网络,输出为每个domain的style code
  • (F)多样性正则

在这里插入图片描述

具体结构

Generator

对AFHQ数据集如下,4个下采样块,4个中间块以及4个上采样块,如下表所示。对CelebA HQ,下采样以及上采样块数加一。
在这里插入图片描述
其结构图如下:
在这里插入图片描述
其代码如下:

class Generator(nn.Module):
    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
        super().__init__()
        dim_in = 2**14 // img_size
        self.img_size = img_size
        self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
        self.encode = nn.ModuleList()
        self.decode = nn.ModuleList()
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))

        # down/up-sampling blocks
        repeat_num = int(np.log2(img_size)) - 4
        if w_hpf > 0: #weight for high-pass filtering
            repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            self.encode.append(
                ResBlk(dim_in, dim_out, normalize=True, downsample=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_in, style_dim,
                               w_hpf=w_hpf, upsample=True))  # stack-like
            dim_in = dim_out

        # bottleneck blocks
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_out, dim_out, normalize=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))

        if w_hpf > 0:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.hpf = HighPass(w_hpf, device)

    def forward(self, x, s, masks=None):
        x = self.from_rgb(x)
        cache = {
   }
        for block in self.encode:
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                cache[x.size(2)] = x
            x = block(x)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                mask = masks[0] if x.size(2) in [32] else masks[1]
                mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
                x = x + self.hpf(mask * cache[x.size(2)])
        return self.to_rgb(x)
        
class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1) ## 分成两块
        return (1 + gamma) * self.norm(x) + beta       
        
class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance ***

class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.w_hpf = w_hpf
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
        return out
        
class HighPass(nn.Module):
    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        self.filter = torch.tensor([[-1, -1, -1],
                                    [-1, 8., -1],
                                    [-1, -1, -1]]).to(device) / w_hpf

    def forward(self, x):
        filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
        return F.conv2d(x, filter, padding=1, groups=x.size(1))

其中HighPass相当于一个边缘提取网络,我写了一个测试如下:

img = cv2.imread('celeb.png')
img_ =torch.from_numpy((img)).float().unsqueeze(0).permute(0,3,1,2)
print(img_.shape)
hpf = HighPass(1,'cpu')
out = hpf(img_).permute(0,2,3,1).numpy()
plt.subplot(121)
plt.imshow(img[:,:,::-1])
plt.subplot(122)
plt.imshow(out[0][:,:,::-1])
plt.show()

HighPass Filter 的处理效果如下:
在这里插入图片描述

Discriminator

在这里插入图片描述
其代码如下:

class Discriminator(nn.Module):
    def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]

        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num
评论 37
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值