starGAN原理代码分析

Pix2Pix模型解决了有Pair对数据的图像翻译问题;CycleGAN解决了Unpaired数据下的图像翻译问题。但无论是Pix2Pix还是CycleGAN,都是解决了一对一的问题,即一个领域到另一个领域的转换。当有很多领域要转换了,对于每一个领域转换,都需要重新训练一个模型去解决。这样的行为太低效了。本文所介绍的StarGAN就是将多领域转换用统一框架实现的算法。

下图是StarGAN的效果,在同一种模型下,可以做多个图像翻译任务,比如更换头发颜色,更换表情,更换年龄等。

引入
如果只能训练一对一的图像翻译模型,会导致两个问题:

训练低效,每次训练耗时很大。
训练效果有限,因为一个领域转换单独训练的话就不能利用其它领域的数据来增大泛化能力。
为了解决多对多的图像翻译问题,StarGAN出现了。

模型框架
StarGAN,顾名思义,就是星形网络结构,在StarGAN中,生成网络G被实现成星形。如下图所示,左侧为普通的Pix2Pix模型要训练多对多模型时的做法,而右侧则是StarGAN的做法,可以看到,StarGAN仅仅需要一个G来学习所有领域对之间的转换。

那么,是什么让G有这样的能力呢?

网络结构
要想让G拥有学习多个领域转换的能力,需要对生成网络G和判别网络D做如下改动。

在G的输入中添加目标领域信息,即把图片翻译到哪个领域这个信息告诉生成模型。
D除了具有判断图片是否真实的功能外,还要有判断图片属于哪个类别的能力。这样可以保证G中同样的输入图像,随着目标领域的不同生成不同的效果
除了上述两样以外,还需要保证图像翻译过程中图像内容要保存,只改变领域差异的那部分。图像重建可以完整这一部分,图像重建即将图像翻译从领域A翻译到领域B,再翻译回来,不会发生变化。
D的训练和G的训练如下所示。

目标函数
首先是GAN的通用函数,判断输出图像是否真实

其次是类别损失,该损失被分成两个,训练D的时候,使用真实图像在原始领域进行,训练G的时候,使用生成的图像在目标领域进行。

训练D的损失:

训练G的损失:

再次则是重建函数,重建函数与CycleGAN中的正向函数类似。

汇总后则是

多数据集训练
在多数据集下训练StarGAN存在一个问题,那就是数据集之间的类别可能是不相交的,但内容可能是相交的。比如CelebA数据集合RaFD数据集,前者拥有很多肤色,年龄之类的类别。而后者拥有的是表情的类别。但前者的图像很多也是有表情的,这就导致前一类的图像在后一类的标记是不可知的。

为了解决这个问题,在模型输入中加入了Mask,即如果来源于数据集B,那么将数据集A中的标记全部设为0.


效果图


更多请参考原始论文.

Reference
[1]. StarGAN: Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation
[2]. Pix2Pix图像翻译
[3]. CycleGAN-Unpaired图像翻译
--------------------- 

下载:

git clone https://github.com/yunjey/StarGAN.git
1
cd StarGAN/
1
下载celebA训练数据:

bash download.sh
1
训练:

python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \
                 --sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \
                 --model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'
1
2
3
代码分析
生成网络
第一个卷积层,输入为图像和label的串联,3表示图像为3通道,c_dim为label的维度,

layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
layers.append(nn.ReLU(inplace=True))
1
2
3
4
2个卷积层,stride=2,即下采样,

# Down-Sampling
curr_dim = conv_dim
for i in range(2):
    layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))
    layers.append(nn.ReLU(inplace=True))
    curr_dim = curr_dim * 2
1
2
3
4
5
6
7
残差层,

# Bottleneck
for i in range(repeat_num):
    layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
1
2
3
残差网络结构,

class ResidualBlock(nn.Module):
    """Residual Block."""
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

    def forward(self, x):
        return x + self.main(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
上采样,

# Up-Sampling
for i in range(2):
    layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
    layers.append(nn.ReLU(inplace=True))
    curr_dim = curr_dim // 2
1
2
3
4
5
6
最后一层,得到输出维度为3,

layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.Tanh())
1
2
self.main = nn.Sequential(*layers)
1
对于输入图像x,label向量c,串联如下,

def forward(self, x, c):
    # replicate spatially and concatenate domain information
    c = c.unsqueeze(2).unsqueeze(3)
    c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
    x = torch.cat([x, c], dim=1)
    return self.main(x)
1
2
3
4
5
6
判别网络
判别网络输入为图像,用于判别输入图像真假,已经输入图像的类别,

class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()

        layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

        curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

        k_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)

    def forward(self, x):
        h = self.main(x)
        out_real = self.conv1(h)
        out_aux = self.conv2(h)
        return out_real.squeeze(), out_aux.squeeze()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
conv1输出维度为1,即判别输入的真假,conv2输出维度为c_dim,即判别输入图像的label.

训练数据,损失函数,参数更新,
输入包括

real_x,real_c,fake_c

fake_c为随机生成的,

# Generat fake labels randomly (target domain labels)
rand_idx = torch.randperm(real_label.size(0))
fake_label = real_label[rand_idx]
if self.dataset == 'CelebA':
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
1
2
3
4
5
6
训练判别网络
将真实图像输入判别网络,

# Compute loss with real images
out_src, out_cls = self.D(real_x)
d_loss_real = - torch.mean(out_src)
1
2
3
判别网络的输入为真实图像,输出out_cls为真实图像对应的标签的概率,则可以计算交叉损失熵,

if self.dataset == 'CelebA':
    d_loss_cls = F.binary_cross_entropy_with_logits(
        out_cls, real_label, size_average=False) / real_x.size(0)
1
2
3
将真实图像输入real_x和假的标签fake_c输入生成网络,得到生成图像fake_x,

fake_x = self.G(real_x, fake_c)
1
将生成图像输入判别网络,

fake_x = Variable(fake_x.data)
out_src, out_cls = self.D(fake_x)
d_loss_fake = torch.mean(out_src)
1
2
3
总的损失函数为,

# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
1
2
根据d_loss更新判别网络参数,

# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
计算梯度惩罚因子alpha,根据alpha结合real_x,fake_x,输入判别网络,计算梯度,得到梯度损失函数,

# Compute gradient penalty
alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
out, out_cls = self.D(interpolated)

grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones(out.size()).cuda(),
                           retain_graph=True,
                           create_graph=True,
                           only_inputs=True)[0]

grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1)**2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
根据梯度损失函数d_loss_gp优化判别网路,

# Backward + Optimize
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
训练生成网络
生成网络的作用是,输入original域的图可以生成目标域的图像,输入为目标域的图像,生成original域的图像.

将原图像输入生成网络,得到生成图像fake_x,同时将fake_x图像输入生成网络,希望生成的图像与真实图像尽量相似,

# Original-to-target and target-to-original domain
fake_x = self.G(real_x, fake_c)
rec_x = self.G(fake_x, real_c)
# Compute losses
g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
1
2
3
4
5
将fake_x输入判别网路,

out_src, out_cls = self.D(fake_x)
g_loss_fake = - torch.mean(out_src)
1
2
计算损失函数,

g_loss_fake = - torch.mean(out_src)
1
对于fake_x,对应的label为fake_label,将fake_x输入判别网络,判别网络预测label概率为out_cls,因此可以计算交叉损失熵,

g_loss_cls = F.binary_cross_entropy_with_logits(
    out_cls, fake_label, size_average=False) / fake_x.size(0)
1
2
生成网络参数更新,

# Backward + Optimize
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
1
2
3
4
5
训练数据处理
以celebA数据为例,下载后的数据包括label文件,和图像.

文件的第一行为图像的总数,为202599.

第二行为数据处理的类别,包括40种,

5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young

第三行及之后的每行为,图像名,已经对应的40种类别的label,label值为1或-1,之后提取为值1为1,-1为0.

000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

list_attr_celeba.txt文件提取函数为,

def preprocess(self):
    attrs = self.lines[1].split()
    for i, attr in enumerate(attrs):
        self.attr2idx[attr] = i
        self.idx2attr[i] = attr

    self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
    self.train_filenames = []
    self.train_labels = []
    self.test_filenames = []
    self.test_labels = []

    lines = self.lines[2:]#the image and labels
    random.shuffle(lines)   # random shuffling
    for i, line in enumerate(lines):

        splits = line.split()
        filename = splits[0]#image name
        values = splits[1:]# labels

        label = []
        for idx, value in enumerate(values):
            attr = self.idx2attr[idx]# there are 40 classes,find the idx th class name

            if attr in self.selected_attrs:#check if the attr in the selected classes
                if value == '1':#if the ckss label is 1 then label equal 2,otherwise,0
                    label.append(1)
                else:
                    label.append(0)

        if (i+1) < 2000:
            self.test_filenames.append(filename)
            self.test_labels.append(label)
        else:
            self.train_filenames.append(filename)
            self.train_labels.append(label)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
self.selected_attrs表示我们训练选用的任务类别集合.最后得到图像名数组self.train_filenames,及其对应的label数组 self.train_labels.

之后采用from torch.utils.data import DataLoader加载训练数据,

data_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         shuffle=shuffle)
1
2
3
fixed_x = []
real_c = []
for i, (images, labels) in enumerate(self.data_loader):
    fixed_x.append(images)
    real_c.append(labels)
    if i == 3:
        break
1
2
3
4
5
6
7
读取后的图像数组为fixed_x,lable为real_c.图像为(bath_size,c_dim,imagesize,imagesize),label为(batch_size,len(self.selected_attrs)).

得到固定的输入图像数组,label,labelist,用于sample.

# Fixed inputs and target domain labels for debugging
fixed_x = torch.cat(fixed_x, dim=0)#4*batch_szie,(64,3,128,128)
fixed_x = self.to_var(fixed_x, volatile=True)
real_c = torch.cat(real_c, dim=0)

if self.dataset == 'CelebA':
    fixed_c_list = self.make_celeb_labels(real_c)
1
2
3
4
5
6
7
labellist生成函数为,

def make_celeb_labels(self, real_c):
    """Generate domain labels for CelebA for debugging/testing.

    if dataset == 'CelebA':
        return single and multiple attribute changes
    elif dataset == 'Both':
        return single attribute changes
    """
    y = [torch.FloatTensor([1, 0, 0]),  # black hair
         torch.FloatTensor([0, 1, 0]),  # blond hair
         torch.FloatTensor([0, 0, 1])]  # brown hair

    fixed_c_list = []

    # single attribute transfer
    for i in range(self.c_dim):
        fixed_c = real_c.clone()
        for c in fixed_c:
            if i < 3:
                c[:3] = y[i]
            else:
                c[i] = 0 if c[i] == 1 else 1   # opposite value
        fixed_c_list.append(self.to_var(fixed_c, volatile=True))

    # multi-attribute transfer (H+G, H+A, G+A, H+G+A)
    if self.dataset == 'CelebA':
        for i in range(4):
            fixed_c = real_c.clone()
            for c in fixed_c:
                if i in [0, 1, 3]:   # Hair color to brown
                    c[:3] = y[2] 
                if i in [0, 2, 3]:   # Gender
                    c[3] = 0 if c[3] == 1 else 1
                if i in [1, 2, 3]:   # Aged
                    c[4] = 0 if c[4] == 1 else 1
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))
    return fixed_c_list
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
fixed_c_list长度为c_dim+4=5+4=9,

训练的时候,fake_label为随机产生0-batch_size的索引,并由索引,从real_label取值,

# Start training
start_time = time.time()
for e in range(start, self.num_epochs):
    for i, (real_x, real_label) in enumerate(self.data_loader):

        # Generat fake labels randomly (target domain labels)
        rand_idx = torch.randperm(real_label.size(0))
        fake_label = real_label[rand_idx]

        if self.dataset == 'CelebA':
            real_c = real_label.clone()
            fake_c = fake_label.clone()
        else:
            real_c = self.one_hot(real_label, self.c_dim)
            fake_c = self.one_hot(fake_label, self.c_dim)

        # Convert tensor to variable
        real_x = self.to_var(real_x)#(16,3,128,128)
        real_c = self.to_var(real_c) #(16,5)          # input for the generator
        fake_c = self.to_var(fake_c)#(16,5)
        real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
        fake_label = self.to_var(fake_label)
--------------------- 
作者:imperfect00 
来源:CSDN 
原文:https://blog.csdn.net/u011961856/article/details/78697863 
版权声明:本文为博主原创文章,转载请附上博文链接!

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值