1. Goal
Use cleanlab to apply confident learning and find mislabeled examples in multi-label classification tasks.
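As background, the core idea of confident learning is easiest to see in the single-label case. First estimate a per-class threshold: the average predicted probability of class j over the examples labeled j. Then assign each example to the most probable class that clears its threshold, and flag examples whose given label disagrees with that class. The sketch below is a simplified illustration written for this note, not cleanlab's implementation. cleanlab also calibrates these counts and offers several pruning strategies (the `prune_method` argument of `get_noise_indices`), so its results can differ. The function name and toy data are hypothetical, and the sketch assumes every class appears at least once among the given labels.

```python
import numpy as np

def find_label_issues_cj(labels, pred_probs):
    """Simplified confident-learning sketch for single-label data (not cleanlab's code).

    labels: (n,) int array of given, possibly noisy, labels
    pred_probs: (n, k) out-of-sample predicted probabilities
    Returns a boolean mask that is True for examples flagged as likely label errors.
    Assumes every class 0..k-1 occurs at least once in `labels`.
    """
    n, k = pred_probs.shape
    # Per-class threshold t_j: mean predicted probability of class j over the
    # examples whose given label is j (the class's average self-confidence).
    thresholds = np.array([pred_probs[labels == j, j].mean() for j in range(k)])
    issues = np.zeros(n, dtype=bool)
    for i in range(n):
        # Classes whose predicted probability reaches that class's threshold.
        confident = np.flatnonzero(pred_probs[i] >= thresholds)
        if confident.size == 0:
            continue  # not confidently assigned to any class, so never flagged
        # Most probable confident class = estimate of the true label.
        guess = confident[np.argmax(pred_probs[i, confident])]
        # Off-diagonal entry of the confident joint: given label != estimated label.
        issues[i] = guess != labels[i]
    return issues

labels = np.array([0, 0, 1, 1, 2, 2])
pred_probs = np.array([[0.8, 0.1, 0.1],
                       [0.1, 0.8, 0.1],   # labeled 0, but looks like class 1
                       [0.2, 0.7, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.1, 0.1, 0.8],
                       [0.1, 0.2, 0.7]])
print(np.flatnonzero(find_label_issues_cj(labels, pred_probs)))  # [1]
```

Only example 1 is flagged. It is labeled 0, but the model confidently predicts class 1. Examples 2 and 5 clear no class threshold, so this sketch does not count them at all.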
2. Code examples
2.1 Single-label
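from cleanlab.pruning import get_noise_indices
from cleanlab import pruning
import numpy as np
# Needed to detect errors in classes with fewer than 5 examples (the default MIN_NUM_PER_CLASS)
pruning.MIN_NUM_PER_CLASS = 1
s = np.array([0,1,0,2,1,0,1])  # given (possibly noisy) labels: one integer class per example
# out-of-sample predicted probabilities obtained via cross-validation
# (toy values; for single-label data each row would normally sum to 1)
psx = np.array([[0.1,0.1,0.9],
                [0.1,0.9,0.0],
                [0.5,0.0,0.5],
                [0.6,0.1,0.1],
                [0.1,0.8,0.1],
                [0.9,0.1,0.0],
                [0.1,0.9,0.9]])
ordered_label_errors = get_noise_indices(
    s=s,  # the labels defined above
    psx=psx,
    sorted_index_method='normalized_margin',  # order label errors, most likely first
    # multi_label is left at its default (False) for single-label data
)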
print(ordered_label_errors)
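The `sorted_index_method='normalized_margin'` argument controls the order of the returned indices. A minimal NumPy sketch of that ordering follows. It assumes the score for each example is the probability of its given label minus the largest probability among the other classes, with the most negative scores listed first. The helper name is hypothetical, and for illustration it ranks all seven toy examples from the single-label example rather than only the flagged ones.

```python
import numpy as np

def order_by_margin(idx, labels, pred_probs):
    """Sketch: sort example indices so the most suspicious come first.

    margin_i = p(given label) - max probability among the other classes;
    a more negative margin means the model disagrees more strongly with the label.
    """
    idx = np.asarray(idx)
    rows = np.arange(len(idx))
    probs = pred_probs[idx].copy()             # probabilities of the examples to rank
    given = labels[idx]
    self_conf = probs[rows, given]             # probability of the given label
    probs[rows, given] = -np.inf               # exclude the given label from the max
    margin = self_conf - probs.max(axis=1)
    return idx[np.argsort(margin, kind="stable")]  # ascending: most negative first

labels = np.array([0, 1, 0, 2, 1, 0, 1])
pred_probs = np.array([[0.1, 0.1, 0.9],
                       [0.1, 0.9, 0.0],
                       [0.5, 0.0, 0.5],
                       [0.6, 0.1, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.9, 0.1, 0.0],
                       [0.1, 0.9, 0.9]])
print(order_by_margin(np.arange(len(labels)), labels, pred_probs))  # [0 3 2 6 4 1 5]
```

Example 0 ranks first. It is labeled 0 with probability 0.1, while class 2 gets 0.9, so its margin is -0.8.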
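2.2 Multi-label
In the multi-label case the given labels s take a different form, because each example can carry several labels. s can be supplied either as a one-hot (multi-hot) matrix or as a list of label lists. get_noise_indices is called with the list-of-lists form, so a one-hot matrix is converted first.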
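Both forms carry the same information. A minimal NumPy sketch of the conversion in both directions follows, using the one-hot labels from the example below. The helper names are hypothetical; the one-hot example itself uses cleanlab's `onehot2int` for the first direction.

```python
import numpy as np

def onehot_to_lists(onehot):
    """(n, k) multi-hot 0/1 matrix -> list of label lists, e.g. [0, 1, 1] -> [1, 2]."""
    return [np.flatnonzero(row == 1).tolist() for row in np.asarray(onehot)]

def lists_to_onehot(label_lists, num_classes):
    """List of label lists -> (n, num_classes) multi-hot int matrix."""
    onehot = np.zeros((len(label_lists), num_classes), dtype=int)
    for i, labels in enumerate(label_lists):
        onehot[i, labels] = 1  # set a 1 in every column listed for example i
    return onehot

s_onehot = np.array([[0, 0, 1],
                     [0, 1, 0],
                     [0, 0, 1],
                     [1, 1, 0],
                     [1, 0, 1],
                     [0, 1, 1],
                     [1, 0, 1]])
s_lists = onehot_to_lists(s_onehot)
print(s_lists)  # [[2], [1], [2], [0, 1], [0, 2], [1, 2], [0, 2]]
print((lists_to_onehot(s_lists, 3) == s_onehot).all())  # True: the round trip is lossless
```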
One-hot (multi-hot) format
from cleanlab.pruning import get_noise_indices
from cleanlab.util import onehot2int
import numpy as np
# Needed to detect errors in classes with fewer than 5 examples (the default MIN_NUM_PER_CLASS)
from cleanlab import pruning
pruning.MIN_NUM_PER_CLASS = 1
# given labels as a one-hot (multi-hot) matrix: one row per example, one column per class
s = [[0,0,1],
     [0,1,0],
     [0,0,1],
     [1,1,0],
     [1,0,1],
     [0,1,1],
     [1,0,1]]
# out-of-sample predicted probabilities (e.g. from cross-validation)
psx = np.array([[0.1,0.1,0.9],
                [0.1,0.9,0.0],
                [0.5,0.0,0.5],
                [0.6,0.1,0.1],
                [0.1,0.8,0.1],
                [0.9,0.1,0.0],
                [0.1,0.9,0.9]])
# convert the one-hot matrix to a list of label lists
correctly_formatted_labels = onehot2int(np.array(s))
ordered_label_errors = get_noise_indices(
    s=correctly_formatted_labels,
    psx=psx,
    sorted_index_method='normalized_margin',  # order label errors, most likely first
    multi_label=True,  # required for multi-label input in cleanlab 0.1.1
)
print(ordered_label_errors)
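The `pruning.MIN_NUM_PER_CLASS = 1` line matters because this toy dataset is tiny. Counting how many examples carry each label shows that classes 0 and 1 have only 3 examples each, below the default of 5 that the comment refers to. Assuming cleanlab keeps at least MIN_NUM_PER_CLASS examples per class when pruning, few or no errors could be reported for those classes without the override. A short check:

```python
import numpy as np

# One-hot labels from the example above: one row per example, one column per class.
s = np.array([[0, 0, 1],
              [0, 1, 0],
              [0, 0, 1],
              [1, 1, 0],
              [1, 0, 1],
              [0, 1, 1],
              [1, 0, 1]])
# Number of examples carrying each label (column sums of the multi-hot matrix).
per_class_counts = s.sum(axis=0)
print(per_class_counts)  # [3 3 5]
# Classes below the default threshold of 5 mentioned in the comment above.
small_classes = np.flatnonzero(per_class_counts < 5)
print(small_classes)  # [0 1]
```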
Non-one-hot format (list of label lists)
from cleanlab.pruning import get_noise_indices
import numpy as np
# Needed to detect errors in classes with fewer than 5 examples (the default MIN_NUM_PER_CLASS)
from cleanlab import pruning
pruning.MIN_NUM_PER_CLASS = 1
# given labels as a list of label lists: one list of class indices per example
s = [[2],
     [1],
     [0,2],
     [0],
     [1],
     [2],
     [1,2],
     ]
psx = np.array([
    [0.1,0.1,0.9],
    [0.1,0.9,0.0],
    [0.5,0.0,0.5],
    [0.6,0.1,0.1],
    [0.1,0.8,0.1],
    [0.9,0.1,0.0],
    [0.1,0.9,0.9],
])
ordered_label_errors = get_noise_indices(
    s=s,
    psx=psx,
    sorted_index_method='normalized_margin',  # order label errors, most likely first
    multi_label=True,
)
print(ordered_label_errors)
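Because get_noise_indices receives plain Python lists here, shape or index mistakes in s are easy to make. A quick pre-check can be run first. The helper below is hypothetical and not part of cleanlab. Its "every class used at least once" rule is an assumption based on confident learning's per-class thresholds, which average over the examples given each label.

```python
import numpy as np

def check_multilabel_inputs(s, psx):
    """Hypothetical pre-check (not part of cleanlab) for list-of-lists labels.

    Returns a list of human-readable problems; an empty list means none were found.
    """
    psx = np.asarray(psx)
    if psx.ndim != 2:
        return ["psx must be a 2-D array of shape (n_examples, n_classes)"]
    n, k = psx.shape
    problems = []
    if len(s) != n:
        problems.append(f"len(s) = {len(s)} but psx has {n} rows")
    seen = set()
    for i, labels in enumerate(s):
        for label in labels:
            if not 0 <= label < k:
                problems.append(f"example {i} has label {label} outside 0..{k - 1}")
            else:
                seen.add(label)
    # Each class should appear as a given label at least once.
    missing = sorted(set(range(k)) - seen)
    if missing:
        problems.append(f"classes never used as a given label: {missing}")
    return problems

s = [[2], [1], [0, 2], [0], [1], [2], [1, 2]]
psx = np.array([[0.1, 0.1, 0.9],
                [0.1, 0.9, 0.0],
                [0.5, 0.0, 0.5],
                [0.6, 0.1, 0.1],
                [0.1, 0.8, 0.1],
                [0.9, 0.1, 0.0],
                [0.1, 0.9, 0.9]])
print(check_multilabel_inputs(s, psx))  # []
```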
3. References
https://github.com/cgnorthcutt/cleanlab/issues/55
https://github.com/cgnorthcutt/cleanlab/issues/23