def count_integer_overlap(rec_labels, true_labels, maxlength=50527):
# much faster (measured with timeit:)
if rec_labels is not None:
test_label_acc = (
1
- (
torch.bincount(rec_labels.view(-1), minlength=maxlength)
- torch.bincount(true_labels[true_labels != -100].view(-1), minlength=maxlength)
)
.abs()
.sum()
/ 2
/ rec_labels.numel()
)
else:
test_label_acc = 0
return test_label_acc
假设 rec_labels
是一个张量 [0, 1, 1, 2, 3, 3]
,而 true_labels
是 [1, 1, 2, 3, -100, -100]
。其中 -100
表示缺失的标签。
-
计算标签的出现次数: 使用
torch.bincount
方法,我们得到以下结果:- 对于
rec_labels
,标签0
出现了1
次,标签1
出现了2
次,标签2
出现了1
次,标签3
出现了2
次,其他标签均未出现。 - 对于
true_labels
,标签1
出现了2
次,标签2
出现了1
次,标签3
出现了1
次。
- 对于
-
调整计数以排除缺失的标签: 在
true_labels
中,我们需要忽略-100
标签的计数。 -
计算绝对差异: 将两个张量的计数相减,并取绝对值。我们得到以下差异:
- 对于标签
0
,差异为1 - 0 = 1
- 对于标签
1
,差异为2 - 2 = 0
- 对于标签
2
,差异为1 - 1 = 0
- 对于标签
3
,差异为2 - 1 = 1
- 对于标签
-
求和并归一化: 将这些差异相加得到总的绝对差异,然后除以标签总数的两倍。在这个例子中,总的绝对差异为
1 + 0 + 0 + 1 = 2
,标签总数为6
,所以重叠度为2 / (2 * 6) = 1/6
。