Conditional GAN: A Brief Analysis and Implementation

Principle

The Conditional GAN¹, or CGAN for short, is an extension of the original GAN. In short, both the generator $G$ and the discriminator $D$ receive one extra input $y$ (the class label of the sample), and the objective becomes:
$$\min_{G} \max_{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text{data}}(\boldsymbol{x})}[\log D(\boldsymbol{x} \mid \boldsymbol{y})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} \mid \boldsymbol{y})))]$$
Apart from this extra conditioning input, everything else is identical to the original GAN.
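
To make the conditioning concrete, the sketch below shows the usual trick of feeding $y$ by concatenation. The shapes are illustrative (a 64-dimensional noise vector and a 10-class one-hot label are assumed, matching the implementation that follows):

import torch
import torch.nn.functional as F

z = torch.randn(8, 64)                                              # a batch of 8 noise vectors
y = F.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()   # class labels as one-hot vectors

g_input = torch.cat((z, y), dim=1)                   # generator input (z, y): shape (8, 74)

x = torch.randn(8, 1, 64, 64)                        # real or generated images
y_map = y[:, :, None, None].expand(-1, -1, 64, 64)   # broadcast the label over the spatial dims
d_input = torch.cat((x, y_map), dim=1)               # discriminator input (x, y): shape (8, 11, 64, 64)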

Implementation

The code follows the earlier posts 《DCGAN原理分析与pytorch实现》 and 《DCGAN Demo》.
The generator and discriminator architectures are essentially those of DCGAN.

Dataset

The MNIST dataset is used:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Resize the 28x28 MNIST digits to opt.image_size (64 here) and normalize to [-1, 1]
my_transforms = transforms.Compose([
  transforms.Resize(opt.image_size),
  transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,)),
])

dataset = MNIST(root='dataset/', train=True, transform=my_transforms, download=True)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)
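
A quick sanity check of one batch (the shapes below assume opt.image_size = 64):

images, targets = next(iter(dataloader))
print(images.shape)    # torch.Size([batch_size, 1, 64, 64]), pixel values in [-1, 1]
print(targets.shape)   # torch.Size([batch_size]), integer class labels 0-9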

Generator

The generator's input x is a 64-dimensional noise vector, and label is a 10-dimensional one-hot class label; the two are concatenated inside the network.

class Generator(nn.Module):
    def __init__(self, z_dim, num_classes):
        super().__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        net = []
        # Five ConvTranspose2d blocks; the 1x1 input grows to 4 -> 8 -> 16 -> 32 -> 64
        channels_in = [self.z_dim+self.num_classes, 512, 256, 128, 64]
        channels_out = [512, 256, 128, 64, 1]
        active = ["R", "R", "R", "R", "tanh"]
        stride = [1, 2, 2, 2, 2]
        padding = [0, 1, 1, 1, 1]
        for i in range(len(channels_in)):
            net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                          kernel_size=4, stride=stride[i], padding=padding[i], bias=False))
            if active[i] == "R":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.ReLU())
            elif active[i] == "tanh":
                net.append(nn.Tanh())

        self.generator = nn.Sequential(*net)

    def forward(self, x, label):
        # x: (N, z_dim) noise, label: (N, num_classes) one-hot
        x = x.unsqueeze(2).unsqueeze(3)          # -> (N, z_dim, 1, 1)
        label = label.unsqueeze(2).unsqueeze(3)  # -> (N, num_classes, 1, 1)
        data = torch.cat(tensors=(x, label), dim=1)
        out = self.generator(data)
        return out
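
As a quick sanity check (the values below are illustrative: z_dim = 64, 10 classes, a batch of 4), the spatial size grows 1 → 4 → 8 → 16 → 32 → 64 through the five transposed convolutions:

netG = Generator(z_dim=64, num_classes=10)
z = torch.randn(4, 64)
y = torch.zeros(4, 10)
y[:, 3] = 1                     # one-hot labels for the digit "3"
print(netG(z, y).shape)         # torch.Size([4, 1, 64, 64])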

Discriminator

The discriminator's input x is a single-channel $64 \times 64$ image, and label is a 10-dimensional one-hot class label; the label is broadcast over the spatial dimensions and concatenated with the image inside the network.

class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        net = []
        # Five Conv2d blocks; the 64x64 input shrinks to 32 -> 16 -> 8 -> 4 -> 1
        # (no BatchNorm after the first convolution, following DCGAN)
        channels_in = [1+self.num_classes, 64, 128, 256, 512]
        channels_out = [64, 128, 256, 512, 1]
        padding = [1, 1, 1, 1, 0]
        active = ["LR", "LR", "LR", "LR", "sigmoid"]
        for i in range(len(channels_in)):
            net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                 kernel_size=4, stride=2, padding=padding[i], bias=False))
            if i == 0:
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "LR":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "sigmoid":
                net.append(nn.Sigmoid())
        self.discriminator = nn.Sequential(*net)

    def forward(self, x, label):
        # x: (N, 1, 64, 64) image, label: (N, num_classes) one-hot
        label = label.unsqueeze(2).unsqueeze(3)
        label = label.repeat(1, 1, x.size(2), x.size(3))  # broadcast the label over the spatial dims
        data = torch.cat(tensors=(x, label), dim=1)
        out = self.discriminator(data)
        out = out.view(data.size(0), -1)
        return out
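
Analogously, a quick shape check (illustrative values: 10 classes, a batch of 4); the spatial size shrinks 64 → 32 → 16 → 8 → 4 → 1:

netD = Discriminator(num_classes=10)
x = torch.randn(4, 1, 64, 64)
y = torch.zeros(4, 10)
y[:, 3] = 1
print(netD(x, y).shape)         # torch.Size([4, 1]), each entry a real/fake probability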

Training
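
The training loop below assumes the setup sketched here. The learning rate and Adam betas are the usual DCGAN choices and are my own assumptions; criterion is binary cross-entropy, and real_output / fake_output are the target values for real and fake samples. opt.channels_noise is assumed to equal the generator's z_dim (64).

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator(z_dim=opt.channels_noise, num_classes=10).to(device)
netD = Discriminator(num_classes=10).to(device)
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))  # assumed hyperparameters
optimizerG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))  # assumed hyperparameters
real_output, fake_output = 1.0, 0.0    # target labels for real / fake samples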

for epoch in range(opt.num_epochs):
  for batch_idx, (data, targets) in enumerate(dataloader):
    data = data.to(device)
    # Convert the integer targets to 10-dimensional one-hot vectors
    targets_temp = torch.zeros([len(targets), 10])
    for i in range(len(targets_temp)):
      targets_temp[i][targets[i]] = 1
    targets = targets_temp.to(device)
    batch_size = data.shape[0]
    ### Train Discriminator: max log(D(x)) + log(1-D(G(z)))
    netD.zero_grad()
    label = (real_output * torch.ones(batch_size)).to(device)   # targets of 1 for real images

    output = netD(data, targets).reshape(-1)
    lossD_real = criterion(output, label)
    D_x = output.mean().item() # Mean confidence of the Discriminator on true imgs.

    noise = torch.randn(batch_size, opt.channels_noise).to(device)
    fake = netG(noise, targets)
    label = (fake_output * torch.ones(batch_size)).to(device)   # targets of 0 for fake images

    output = netD(fake.detach(), targets).reshape(-1)
    lossD_fake = criterion(output, label)

    lossD = lossD_real + lossD_fake
    lossD.backward()
    optimizerD.step()

    ### Train Generator: max log(D(G(z)))
    netG.zero_grad()
    label = torch.ones(batch_size).to(device)
    output = netD(fake, targets).reshape(-1)
    lossG = criterion(output, label)
    lossG.backward()
    optimizerG.step()

Results

Generated samples

After 40 epochs the generated digits look quite satisfactory, and no obvious class-conditioning mistakes were observed even from the first epoch. The outputs are noticeably sharper than those in the original paper, probably because a DCGAN architecture is used here.
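
Once training has finished, one sample per class can be generated by pairing ten noise vectors with the ten one-hot labels. A minimal sketch (the file name and manual rescaling are my own choices):

from torchvision.utils import save_image

netG.eval()
with torch.no_grad():
    z = torch.randn(10, opt.channels_noise, device=device)
    labels = torch.eye(10, device=device)         # one-hot labels for digits 0-9
    samples = netG(z, labels)                     # (10, 1, 64, 64), values in [-1, 1]
    samples = (samples + 1) / 2                   # rescale to [0, 1] for saving
save_image(samples, "cgan_samples.png", nrow=10)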


  1. Mirza, M. and Osindero, S., 2014. Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784.
