本周任务
1、进一步学习CGAN
2、在上周的基础上生成指定图像
CGAN模型
cGAN的中心思想是希望 可以控制 GAN 生成的图片,而不 是单纯的随机生成图片。 具体来说,Conditional GAN 在生成器和判别器的输入中 增加了额外的 条件信息,生成器生成的图片只有足够真实 且与条件相符,才能够通过判别器。
实际上 , 在无条件约束的生成模型中 , 没法控制数据生成的模式。然而,通过额外的信息对模型进行约束,有可能指导数据生成的过程。条件约束可以是类标签 , 可以是图像修补的部分数据, 甚至是来自不同模态的数据
cGAN将 无监督学习 转为 有监督学习 使得网络可以更好地在我们的掌控下进行学习!
从公式看,CGAN相当于在原始GAN的基础上对生成器部分 和判别器部分都加了一个条件
代码可以参考上一期的文章
深度学习-第G3周:CGAN|生成手势图像_quant_day的博客-CSDN博客
生成指定图片
#%%
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec
# 导入生成器模型
generator.load_state_dict(torch.load('generator_epoch_300.pth'), strict=False)
generator.eval()
interpolated = randn(100) # 生成两个潜在空间的点
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)
label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()
predictions = generator((interpolated, labels))
predictions = predictions.permute(0,2,3,1).detach().cpu()
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
plt.figure(figsize=(8,3))
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show