一. cgan
二. 训练过程
1. Train Generator
(1)定义valid 和fake,定义real_imgs和labels
valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)
(2)
optimizer_G.zero_grad( )
(3)随机生成 z 与 gen_labels(0~9)
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
gen_labels =Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size) ))
(4)生成器生成图片
gen_imgs = generator(z, gen_labels)
(5)计算生成器loss
validity = discriminator(gen_imgs, gen_labels)
g_loss = adversarial_loss(validity, valid)
(6)更新G
g_loss.backward()
optimizer_G.step()
2. Train Discriminator
(1)
optimizer_D.zero_grad( )
(2)计算假图片的loss
validity_real = discriminator(real_imgs, labels)
d_real_loss = adversarial_loss(validity_real, valid)
(3)计算真图片的loss
validity_fake = discriminator(gen_imgs.detach(), gen_labels)
d_fake_loss = adversarial_loss(validity_fake, fake)
(4)计算总loss
d_loss = (d_real_loss + d_fake_loss) / 2
(5)更新D
d_loss.backward()
optimizer_D.step()
3. 一些函数
1. numpy.random.randint(low, high, size)
low、high、size三个参数。默认high是None,如果只有low,那范围就是[0,low)。如果有high,范围就是[low,high)
>>> a=np.random.randint(0, 10, 64)
>>> a
array([3, 0, 2, 2, 9, 8, 7, 8, 0, 0, 9, 5, 2, 0, 4, 6, 4, 9, 8, 7, 0, 9,
2, 6, 1, 3, 5, 3, 8, 5, 3, 9, 6, 6, 3, 7, 9, 6, 8, 4, 5, 2, 0, 0,
4, 0, 1, 8, 1, 7, 0, 4, 3, 8, 5, 4, 4, 6, 8, 2, 2, 9, 3, 8])
2. numpy.random.normal(loc=0.0, scale=1.0, size=None)
loc:float
此概率分布的均值(对应着整个分布的中心centre)
scale:float
此概率分布的标准差(对应于分布的宽度,scale越大越矮胖,scale越小,越瘦高)
size:int or tuple of ints
输出的shape,默认为None,只输出一个值
三. 保存图片
def sample_image(n_row, batches_done):
"""Saves a grid of generated digits ranging from 0 to n_classes"""
# Sample noise
z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
# Get labels ranging from 0 to n_classes for n rows
labels = np.array([num for _ in range(n_row) for num in range(n_row)])
labels = Variable(LongTensor(labels))
gen_imgs = generator(z, labels)
save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
四. 结果展示
源代码网址:https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cgan