StarGAN多数据集训练

我们进行多数据集训练时,引入了掩码向量mask vector,我们引入一个掩码向量m,允许StarGAN忽略未知的标签。
在这里插入图片描述
在用多个数据集训练时把mask向量添加到生成器中,生成器G忽略未指定的标签(零向量),并关注给定的标签。除了输入标签的维度外,生成器的结构与单个数据集的训练完全相同。
因为训练涉及两个数据集,所以在训练时,判别器每次只针对当前已知的标签,来最小化分类误差。例如,当训练是基于CelebA时,判别器最小化的目标只是与CelebA属性相关的分类误差。通过在CelebA和Fer2013之间交替变换,判别器学到了两个数据集上的所有特征。
在这里插入图片描述

def label2onehot(self, labels, dim):
        """将标签索引转换为一个one-hot量。"""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size,
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值