我们进行多数据集训练时,引入了掩码向量mask vector,我们引入一个掩码向量m,允许StarGAN忽略未知的标签。
在用多个数据集训练时把mask向量添加到生成器中,生成器G忽略未指定的标签(零向量),并关注给定的标签。除了输入标签的维度外,生成器的结构与单个数据集的训练完全相同。
因为训练涉及两个数据集,所以在训练时,判别器每次只针对当前已知的标签,来最小化分类误差。例如,当训练是基于CelebA时,判别器最小化的目标只是与CelebA属性相关的分类误差。通过在CelebA和Fer2013之间交替变换,判别器学到了两个数据集上的所有特征。
def label2onehot(self, labels, dim):
"""将标签索引转换为一个one-hot量。"""
batch_size = labels.size(0)
out = torch.zeros(batch_size,