基于中心对齐的领域泛化损失

class Center_loss(nn.Module):
    def __init__(self,src_class):
        super(Center_loss, self).__init__()

        self.n_class=src_class
        self.MSELoss = nn.MSELoss()  # (x-y)^2
        self.MSELoss = self.MSELoss.cuda()




    def forward(self, feature,label,domain_index):
        domian_index_1 = (domain_index == 0)
        domian_index_2 = (domain_index == 1)
        domian_index_3 = (domain_index == 2)

        label1 = label[domian_index_1]
        label2 = label[domian_index_2]
        label3 = label[domian_index_3]

        feature1 = feature[domian_index_1]
        feature2 = feature[domian_index_2]
        feature3 = feature[domian_index_3]

        s1, d = feature1.shape
        s2, d = feature2.shape
        s3, d = feature3.shape

        ones1 = t.ones_like(label1, dtype=t.float)
        ones2 = t.ones_like(label2, dtype=t.float)
        ones3 = t.ones_like(label3, dtype=t.float)

        zeros = t.zeros(self.n_class)

        zeros = zeros.cuda()

        n_classes1 = zeros.scatter_add(0, label1, ones1)
        n_classes2 = zeros.scatter_add(0, label2, ones2)
        n_classes3 = zeros.scatter_add(0, label3, ones3)

        # image number cannot be 0, when calculating centroids
        s_ones1 = t.ones_like(n_classes1)
        s_ones2 = t.ones_like(n_classes2)
        s_ones3 = t.ones_like(n_classes3)

        n_classes1 = t.max(n_classes1, s_ones1)
        n_classes2 = t.max(n_classes2, s_ones2)
        n_classes3 = t.max(n_classes3, s_ones3)

        # calculating centroids, sum and divide
        zeros = t.zeros(self.n_class, d)

        zeros = zeros.cuda()

        s_sum_feature1 = zeros.scatter_add(0, t.transpose(label1.repeat(d, 1), 1, 0), feature1)
        s_sum_feature2 = zeros.scatter_add(0, t.transpose(label2.repeat(d, 1), 1, 0), feature2)
        s_sum_feature3 = zeros.scatter_add(0, t.transpose(label3.repeat(d, 1), 1, 0), feature3)

        current_s_centroid1 = t.div(s_sum_feature1, n_classes1.view(self.n_class, 1))
        current_s_centroid2 = t.div(s_sum_feature2, n_classes2.view(self.n_class, 1))
        current_s_centroid3 = t.div(s_sum_feature3, n_classes3.view(self.n_class, 1))

        semantic_loss = self.MSELoss(current_s_centroid1, current_s_centroid2) + \
                        self.MSELoss(current_s_centroid1, current_s_centroid3) + \
                        self.MSELoss(current_s_centroid2, current_s_centroid3)


        return semantic_loss


   
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值