损失和代码来自paper:Do We Really Need to Access the Source Data? Source Hypothesis Transfer for Unsupervised Domain Adaptation; Jian Liang, Dapeng Hu, Jiashi Feng
可以理解为个体entropy loss,
可以理解为群体entropy loss(多样性损失)。
代码如下:
outputs_test = model(inputs_test)
# output.size (64, 9)
softmax_out = nn.Softmax(dim=1)(outputs_test)
# entropy loss
entropy_loss = torch.mean(Entropy(softmax_out, reduction='sum'))
# divergence loss
msoftmax = softmax_out.mean(dim=0)
gentropy_loss = torch.sum(-msoftmax * torch.log(msoftmax + args.epsilon))
entropy_loss -= gentropy_loss
分类任务,设output.size为(64,9)即batch-size 64,num-classes 9。
首先对每个样本,将classifier feature softmax (64,9),对每个样本计算entropy,加和后取期望值,得到。
再把各样本分类概率(64,9)按列求均值后,得到每个类别的平均预测概率(9,)因此每个概率综合不同样本概率(多样性)即整个batch的平均输出的熵。这里使用了一个小的常数 epsilon
来避免对零取对数。
注意:最终优化函数为降低个体熵损失而增加群体熵损失,意在对每个单独的样本给出确定性的预测,而且在所有样本上给出均衡的预测(增加多样性)。
熵的高低反映了随机变量的不确定性水平。增加熵以鼓励模型探索和多样性,减少熵以提高确定性和效率。