Principle
Conditional GAN [1], CGAN for short, is an extension of the original GAN. In short, the generator G and the discriminator D each receive one extra input y (the class label of the sample), and the objective can be written as:
$$\min_{G}\max_{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text{data}}(\boldsymbol{x})}[\log D(\boldsymbol{x} \mid \boldsymbol{y})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z} \mid \boldsymbol{y})))]$$
Apart from this small change to the inputs, everything is identical to the original GAN.
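To make the connection to the code below concrete, the two expectations can be read as the usual binary cross-entropy losses, just with the label fed into both networks. The following is a minimal sketch under that reading, not the exact code used later; netD(x, y) and netG(z, y) are assumed to accept the one-hot label y as a second argument.

import torch
import torch.nn as nn

# Minimal sketch of the CGAN objective as two BCE losses (illustrative only).
bce = nn.BCELoss()

def discriminator_loss(netD, netG, x, y, z):
    # D ascends  E[log D(x|y)] + E[log(1 - D(G(z|y)))]
    real = bce(netD(x, y).view(-1), torch.ones(x.size(0), device=x.device))
    fake = bce(netD(netG(z, y).detach(), y).view(-1), torch.zeros(x.size(0), device=x.device))
    return real + fake

def generator_loss(netD, netG, y, z):
    # Non-saturating variant: G maximizes E[log D(G(z|y))] instead of minimizing E[log(1 - D(...))]
    return bce(netD(netG(z, y), y).view(-1), torch.ones(z.size(0), device=z.device))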
Implementation
The code is adapted from the posts "DCGAN原理分析与pytorch实现" and "DCGAN Demo".
The generator and discriminator architectures are essentially those of DCGAN.
Dataset
The MNIST dataset is used.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# opt is an argparse namespace holding the hyper-parameters (image_size, batch_size, ...).
# Normalize to [-1, 1] to match the Tanh output layer of the generator.
my_transforms = transforms.Compose([
    transforms.Resize(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = MNIST(root='dataset/', train=True, transform=my_transforms, download=True)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)
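As a quick, hedged sanity check of the pipeline (assuming opt.image_size = 64 and opt.batch_size = 128, values not stated explicitly in the snippet), one batch looks like this:

images, targets = next(iter(dataloader))
print(images.shape)   # torch.Size([128, 1, 64, 64]) -- single-channel images scaled to [-1, 1]
print(targets.shape)  # torch.Size([128])             -- integer class indices 0..9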
Generator
The generator's input x is 64-channel noise and label is a 10-channel one-hot class label; the two are concatenated along the channel dimension inside the network.
class Generator(nn.Module):
    def __init__(self, z_dim, num_classes):
        super().__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        net = []
        # DCGAN-style stack of transposed convolutions:
        # (z_dim + num_classes) x 1 x 1 -> 512 x 4 x 4 -> ... -> 1 x 64 x 64
        channels_in = [self.z_dim + self.num_classes, 512, 256, 128, 64]
        channels_out = [512, 256, 128, 64, 1]
        active = ["R", "R", "R", "R", "tanh"]
        stride = [1, 2, 2, 2, 2]
        padding = [0, 1, 1, 1, 1]
        for i in range(len(channels_in)):
            net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                          kernel_size=4, stride=stride[i], padding=padding[i], bias=False))
            if active[i] == "R":
                # BatchNorm + ReLU on all but the output layer
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.ReLU())
            elif active[i] == "tanh":
                # Tanh maps the output to [-1, 1], matching the normalized images
                net.append(nn.Tanh())
        self.generator = nn.Sequential(*net)

    def forward(self, x, label):
        # x: (N, z_dim) noise, label: (N, num_classes) one-hot labels
        x = x.unsqueeze(2).unsqueeze(3)                  # -> (N, z_dim, 1, 1)
        label = label.unsqueeze(2).unsqueeze(3)          # -> (N, num_classes, 1, 1)
        data = torch.cat(tensors=(x, label), dim=1)      # concatenate along channels
        out = self.generator(data)
        return out
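Before wiring it into training, a small shape check may help; this is just a sketch with z_dim = 64 and num_classes = 10 as described above, and the random one-hot labels are purely illustrative:

netG = Generator(z_dim=64, num_classes=10)
z = torch.randn(8, 64)                          # 8 noise vectors
y = torch.eye(10)[torch.randint(0, 10, (8,))]   # 8 one-hot class labels
print(netG(z, y).shape)                         # torch.Size([8, 1, 64, 64])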
Discriminator
The discriminator's input x is a single-channel 64×64 image and label is a 10-channel one-hot class label; the label is broadcast spatially and concatenated with the image along the channel dimension inside the network.
class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        net = []
        # DCGAN-style stack of strided convolutions:
        # (1 + num_classes) x 64 x 64 -> 64 x 32 x 32 -> ... -> 1 x 1 x 1
        channels_in = [1 + self.num_classes, 64, 128, 256, 512]
        channels_out = [64, 128, 256, 512, 1]
        padding = [1, 1, 1, 1, 0]
        active = ["LR", "LR", "LR", "LR", "sigmoid"]
        for i in range(len(channels_in)):
            net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                 kernel_size=4, stride=2, padding=padding[i], bias=False))
            if i == 0:
                # No BatchNorm on the first layer, following DCGAN
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "LR":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "sigmoid":
                # Sigmoid outputs the probability that the input is real
                net.append(nn.Sigmoid())
        self.discriminator = nn.Sequential(*net)

    def forward(self, x, label):
        # x: (N, 1, 64, 64) image, label: (N, num_classes) one-hot labels
        label = label.unsqueeze(2).unsqueeze(3)           # -> (N, num_classes, 1, 1)
        label = label.repeat(1, 1, x.size(2), x.size(3))  # broadcast over the spatial dims
        data = torch.cat(tensors=(x, label), dim=1)       # concatenate along channels
        out = self.discriminator(data)
        out = out.view(data.size(0), -1)                  # -> (N, 1)
        return out
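A matching sketch for the discriminator, again with illustrative inputs:

netD = Discriminator(num_classes=10)
imgs = torch.randn(8, 1, 64, 64)                # a batch of 8 (fake) images
y = torch.eye(10)[torch.randint(0, 10, (8,))]   # 8 one-hot class labels
print(netD(imgs, y).shape)                      # torch.Size([8, 1]) -- probability of being real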
Training
for epoch in range(opt.num_epochs):
    for batch_idx, (data, targets) in enumerate(dataloader):
        data = data.to(device)
        # Convert the integer class indices into one-hot vectors for conditioning.
        targets_temp = torch.zeros([len(targets), 10])
        for i in range(len(targets_temp)):
            targets_temp[i][targets[i]] = 1
        targets = targets_temp.to(device)
        batch_size = data.shape[0]

        ### Train Discriminator: max log(D(x|y)) + log(1 - D(G(z|y)))
        netD.zero_grad()
        # real_output / fake_output are the target values for real and fake samples
        # (typically 1.0 and 0.0), defined alongside the other hyper-parameters.
        label = (real_output * torch.ones(batch_size)).to(device)
        output = netD(data, targets).reshape(-1)
        lossD_real = criterion(output, label)
        D_x = output.mean().item()  # mean confidence of the discriminator on real images
        noise = torch.randn(batch_size, opt.channels_noise).to(device)
        fake = netG(noise, targets)
        label = (fake_output * torch.ones(batch_size)).to(device)
        output = netD(fake.detach(), targets).reshape(-1)
        lossD_fake = criterion(output, label)
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        ### Train Generator: max log(D(G(z|y)))
        netG.zero_grad()
        label = torch.ones(batch_size).to(device)
        output = netD(fake, targets).reshape(-1)
        lossG = criterion(output, label)
        lossG.backward()
        optimizerG.step()
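Once training is done, sampling a particular digit only requires fixing the one-hot label. A hedged sketch (the class index 7 and the sample count 16 are arbitrary choices):

netG.eval()
with torch.no_grad():
    z = torch.randn(16, opt.channels_noise, device=device)
    y = torch.zeros(16, 10, device=device)
    y[:, 7] = 1                         # condition on the digit "7"
    samples = netG(z, y)                # (16, 1, 64, 64), values in [-1, 1]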
Results
The samples generated after 40 epochs are quite satisfactory, and no obvious class-conditioning errors were observed from the very first epoch onward. Probably because a DCGAN backbone is used, the outputs are noticeably sharper than those in the original paper.