网络结构
论文中的网络结构图如下,embedding的提取直接使用预训练好的text encoder进行提取(不是本文重点)。提出的StackGAN整个模型包含2个GAN网络,分别用于两个阶段:
Stage1 :embedding+ noise 为输入,利用GAN输出低分辨率的64x64大小的影像;
Stage2 :embedding+ Stage I的低分辨率生成影像 为输入,利用GAN输高分辨率的256x256大小的影像
结合代码,stage I与stage II 的详细结构如下:
注意:其实代码中stage II 鉴别器输出的logit 有两种,分为condition 和uncondition,分别对应着有无引入embedding信息。(图中只显示了condition的logit输出)
每个阶段的GAN训练流程是相同的:
- 生成fake img;
- 训练鉴别器。考虑三种鉴别器输入,(1)real pairs:真实图像与对应的文本embedding,gt为 1;(2)wrong pairs:真实图像与不匹配的文本embedding,gt为 0;(3) fake pairs:生成