1.简单的GAN训练流程
1训练流程
- 创建标签,判别器Discriminator区分Real Image和Fake Image归根结底是一个二分类。这里不能用数据集自带标签。
valid = Tensor(image.size(0), 1).fill_(1.0).detach() # Tensor(batch size row,1), fill 1.0 mark Real Image
fake = Tensor(image.size(0), 1).fill_(0.0).detach() # .detach() mean with no grad / requires_grad = False
- 先训练Generator
第一步:先用Tensor生成符合任意分布的噪声数据z,z . shape = (Batch Size, input_Dim)例如程序中将噪声输入数据维度可设置为100。
第二步:生成数据G(z),并对G(z)添加标签,目的是为了愚弄判别器。
第三步:生成损失值,这里损失值来源于判别器D(G(z))
第四步:BP & 更新步骤
optimizer_G.zero_grad() # 对已有的gradient清零,因为有新来的Batch
# 第一步
z = Tensor(np.random.normal(0, 1, (batch_size, input_Dim)))
# 第二步
gen_imgs = generator(z)
# 第三步
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
# 第四步
g_loss.backward(