class Center_loss(nn.Module):
    """Cross-domain class-centroid alignment loss.

    The batch is split by ``domain_index`` into domains ``0 .. n_domains-1``.
    For each domain a per-class mean feature (centroid matrix of shape
    ``(n_class, d)``) is computed, and the loss is the sum of ``nn.MSELoss``
    between the centroid matrices of every pair of domains, in the order
    (0, 1), (0, 2), (1, 2), ...

    A class that has no samples in a domain (or a domain with no samples at
    all) gets an all-zero centroid row, and that zero row still takes part in
    the MSE terms. Rows whose domain id is outside ``range(n_domains)`` are
    ignored.
    """

    def __init__(self, src_class, n_domains=3):
        """
        Args:
            src_class: number of classes; labels must lie in ``[0, src_class)``.
            n_domains: number of domains to align (ids ``0 .. n_domains-1``).
                Defaults to 3, the value previously hard-coded.
        """
        super(Center_loss, self).__init__()
        self.n_class = src_class
        self.n_domains = n_domains
        # MSELoss has no parameters or buffers, so it needs no device move
        # (the former ``.cuda()`` call here was a no-op).
        self.MSELoss = nn.MSELoss()

    def _class_centroids(self, feature, label):
        """Return the ``(n_class, d)`` per-class mean of the rows of ``feature``.

        ``feature`` is ``(N, d)`` and ``label`` is a 1-D integer tensor of
        length ``N``. Buffers are created with ``feature``'s device and dtype,
        so the loss works on CPU/GPU and for any floating dtype.
        """
        dim = feature.shape[1]
        counts = feature.new_zeros(self.n_class).index_add(
            0, label, feature.new_ones(label.shape[0]))
        # Classes missing from this domain have a zero sum; dividing by 1
        # instead of 0 turns them into zero centroids rather than NaNs.
        counts = counts.clamp(min=1)
        sums = feature.new_zeros(self.n_class, dim).index_add(0, label, feature)
        return sums / counts.unsqueeze(1)

    def forward(self, feature, label, domain_index):
        """Compute the summed pairwise MSE between per-domain class centroids.

        Args:
            feature: ``(N, d)`` feature tensor.
            label: 1-D integer class labels of length ``N``.
            domain_index: 1-D tensor of length ``N`` with the domain id per row.

        Returns:
            A scalar tensor (dtype/device of ``feature``).
        """
        centroids = []
        for dom in range(self.n_domains):
            mask = domain_index == dom
            centroids.append(self._class_centroids(feature[mask], label[mask]))

        # Start from an exact 0 so the 3-domain sum equals the original
        # ``mse01 + mse02 + mse12`` (same terms, same order).
        semantic_loss = feature.new_zeros(())
        for i, first in enumerate(centroids):
            for second in centroids[i + 1:]:
                semantic_loss = semantic_loss + self.MSELoss(first, second)
        return semantic_loss