辅助分类器生成对抗网络( Auxiliary Classifier Generative Adversarial Network,ACGAN)(附pytorch代码)

1 ACGAN基本原理

1.2 ACGAN模型解释

ACGAN相对于CGAN使的判别器不仅可以判别真假,也可以判别类别 。通过对生成数据类别的判断,判别器可以更好地传递loss函数使得生成器能够更加准确地找到label对应的噪声分布,通过下图告诉了我们ACGAN与CGAN的异同之处 :

1721649945655

对于CGAN和ACGAN,生成器输入均为潜在矢量及其标签,输出是属于输入类标签的伪造数据。对于CGAN,判别器的输入是数据(包含假的或真实的数据)及其标签, 输出是图像属于真实数据的概率。对于ACGAN,判别器的输入是数据,而输出是该图像属于真实数据的概率以及其类别概率。

在ACGAN中,对于生成器来说有两个输入,一个是标签的分类数据c,另一个是随机数据z,得到生成数据为 X_{fake}=G(photo/gif.gif) ;对于判别器,产生跨域标签和源数据的概率分布

1.2 ACGAN损失函数

对于判别器而言,即希望分类正确,有希望能正确分辨数据的真假;对于生成器而言,也希望分类正确,但希望判别器不能正确分辨真假。因此在训练判别器的时候,我们希望LSE+LCS最大化;在训练生成器的时候,我们希望LCS-LSE最大化。

在这里插入图片描述

  • logP(SR = real | Xreal)表示鉴别器将真实样本源正确分类为真实样本的对数似然;
  • logP(SR = fake | Xfake)表示鉴别器正确地将假样本的来源分类为假样本的对数似然
  • E[.]表示所有样本的平均值
  • logP(CS = CS | Xreal)表示鉴别器正确分类真实样本的对数似然
  • logP(CS = CS | Xfake)表示鉴别器正确分类具有正确类别标签的假样本的对数似然

判别器的损失函数 = LSE + LCS;生成器的损失函数 = LCS - LSE

  • LSE测量鉴别器正确区分样本是真还是假的程度。这有助于鉴别器熟练地识别来源(真实的或生成的)。
  • LCS确保生成的样本不仅看起来真实,而且携带正确的类信息。它引导生成器在不同的类中产生多样化和现实的样本。

2 ACGAN pytorch代码实现

完整代码链接:https://github.com/znxlwm/pytorch-generative-model-collections/tree/master

(但是这个代码我训练的时候损失函数也对应的上,得到的图片是黑乎乎的一片,也不知道是什么原因,如果知道的师傅可以麻烦告知一下吗?)

这个代码在训练ACGAN模型的时候加载数据集的时候会出现问题,因为我使用的是minist数据集,所以应该改为单通道的:

1721782164361

import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from dataloader import dataloader

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10):
        super(generator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim + self.class_num, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Tanh(),
        )
        utils.initialize_weights(self)

    def forward(self, input, label):
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10):
        super(discriminator, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_size = input_size
        self.class_num = class_num

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        self.dc = nn.Sequential(
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        self.cl = nn.Sequential(
            nn.Linear(1024, self.class_num),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
        x = self.fc1(x)
        d = self.dc(x)
        c = self.cl(x)

        return d, c

class ACGAN(object):
    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size    # 输入图像的尺寸
        self.z_dim = 62    # 潜在向量维度
        self.class_num = 10
        self.sample_num = self.class_num ** 2   # 总样本的数量

        # load dataset
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)   # 加载数据集
        data = self.data_loader.__iter__().__next__()[0]    # 获得第一个批次的数据,data 的形状通常是 (batch_size, channels, height, width)


        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        # 查看是否启用了gpu模式
        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()     # 将交叉熵损失加载到GPU
            self.CE_loss = nn.CrossEntropyLoss().cuda()     # 将二元交叉熵损失加载到GPU
        else:
            self.BCE_loss = nn.BCELoss()
            self.CE_loss = nn.CrossEntropyLoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        p
  • 9
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值