《Probabilistic label trees for extreme multi-label classification》
核心思想:根据XC的树状层次结构,将所有训练样本赋给树中的所有结点,并判断样本
x
\mathbf{x}
x在结点
v
v
v上是正例还是负例。
判断依据是,样本
x
\mathbf{x}
x在结点
v
v
v的所有叶子结点上的标签是否包含1,如果包含,那么其在
v
v
v上就是正例,否则为负例。
这样每个结点就有了一个训练样本集(只有正例和负例),然后为每一个结点训练一个binary分类器。
问题定义
符号系统:
Key notations | Meaning |
---|---|
X \mathcal{X} X | instance space |
L = { 1 , … , m } \mathcal{L} = \{1, \dots, m\} L={1,…,m} | Label set |
Y = { 0 , 1 } m \mathcal{Y} = \{0, 1\}^m Y={0,1}m | Label space |
x ∈ X \mathbf{x} \in \mathcal{X} x∈X | an instance |
y ∈ Y \mathbf{y} \in \mathcal{Y} y∈Y | a label corresponding to x \mathbf{x} x |
L x ⊆ L \mathcal{L}_\mathbf{x} \subseteq \mathcal{L} Lx⊆L | relevant(positive) labels, otherwise irrelevant(positive) labels. y j = 1 ⇔ j ∈ L x y_j = 1 \Leftrightarrow j \in \mathcal{L}_\mathbf{x} yj=1⇔j∈Lx |
R ( ⋅ ) R(\cdot) R(⋅) | The expected loss, or risk |
P ( x , y ) \mathbf{P}(\mathbf{x},\mathbf{y}) P(x,y) | 观测 ( x , y ) (\mathbf{x},\mathbf{y}) (x,y)的概率分布, 假定每个观测独立采样 |
ℓ ( y , y ^ ) \ell(\mathbf{y},\hat{\mathbf{y}}) ℓ(y,y^) | Loss |
T T T | The tree |
L T L_T LT | leaf set; l j ∈ L T l_j \in L_T lj∈LT对应 j ∈ L j \in \mathcal{L} j∈L |
V T V_T VT | the set of all nodes |
L v ⊆ L T L_v \subseteq L_T Lv⊆LT | 内节点 v v v的所有叶子 |
L v ⊆ L \mathcal{L}_v \subseteq \mathcal{L} Lv⊆L | 内节点 v v v对应的所有叶子的标签集合 |
↑ ( v ) , ↓ ( v ) \uparrow(v), \downarrow(v) ↑(v),↓(v) | 父节点,直接孩子节点集合 |
Path ( v ) \text{Path}(v) Path(v) | 从 v v v到根节点的路径 |
len v \text{len}_v lenv | 路径长度 |
deg v \text{deg}_v degv | 节点 v v v的度 |
本文作者的问题定义写的很好,读起来很通畅。先前也看了一些XC的文章,都没有将问题定义描述的很好(或者压根没有问题定义)。
极限多标签分类问题可定义为(类似于多标签分类问题的定义):寻找一个分类器
h
(
x
)
=
(
h
1
(
x
)
,
…
,
h
m
(
x
)
)
∈
H
m
:
X
↦
R
m
\mathbf{h}(\mathbf{x}) = (h_1(\mathbf{x}),\dots,h_m(\mathbf{x})) \in \mathcal{H}^m:\mathcal{X}\mapsto \mathbb{R}^m
h(x)=(h1(x),…,hm(x))∈Hm:X↦Rm,使得期望损失极小:
R
ℓ
(
h
)
=
E
(
x
,
y
)
∼
P
(
x
,
y
)
(
ℓ
(
y
,
h
(
x
)
)
)
R_\ell(\mathbf{h}) = \mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathbf{P}(\mathbf{x},\mathbf{y})}(\ell(\mathbf{y},\mathbf{h}(\mathbf{x})))
Rℓ(h)=E(x,y)∼P(x,y)(ℓ(y,h(x)))
一般地,
m
≥
1
0
5
,
∣
L
x
∣
≪
m
m\geq 10^5,|\mathcal{L}_\mathbf{x}| \ll m
m≥105,∣Lx∣≪m。那么在损失
ℓ
\ell
ℓ上的最优分类器为:
h
ℓ
∗
=
arg min
h
R
ℓ
(
h
)
\mathbf{h}_\ell^* = \argmin_{\mathbf{h}} R_\ell(\mathbf{h})
hℓ∗=hargminRℓ(h)
文中定义了一个分类器
h
\mathbf{h}
h针对损失
ℓ
\ell
ℓ的遗憾(regret):
reg
ℓ
(
h
)
=
R
ℓ
(
h
)
−
R
ℓ
(
h
ℓ
∗
)
=
R
ℓ
(
h
)
−
R
ℓ
∗
\text{reg}_\ell(\mathbf{h}) = R_\ell(\mathbf{h}) - R_\ell(\mathbf{h}_\ell^*) = R_\ell(\mathbf{h}) - R_\ell^*
regℓ(h)=Rℓ(h)−Rℓ(hℓ∗)=Rℓ(h)−Rℓ∗
当然它越小越好。
模型希望
L
1
L_1
L1估计误差最小:
∣
P
(
y
j
=
1
∣
x
)
−
P
^
(
y
j
=
1
∣
x
)
∣
|P(y_j=1|\mathbf{x}) - \hat{P}(y_j=1|\mathbf{x})|
∣P(yj=1∣x)−P^(yj=1∣x)∣
令
ℓ
log
\ell_\text{log}
ℓlog为交叉熵损失,其在样本
x
\mathbf{x}
x上的条件风险(也就是期望损失)为:
E
y
ℓ
log
(
y
,
h
(
x
)
)
=
∑
j
=
1
m
R
log
(
h
j
(
x
)
∣
x
)
\mathbb{E}_\mathbf{y}\ell_{\text{log}}(\mathbf{y},\mathbf{h}(\mathbf{x})) = \sum_{j=1}^m R_\text{log}(h_j(\mathbf{x})|\mathbf{x})
Eyℓlog(y,h(x))=j=1∑mRlog(hj(x)∣x)
那么最优预测为
h
j
∗
(
x
)
=
arg min
h
R
log
(
h
j
(
x
)
∣
x
)
h_j^*(\mathbf{x}) = \argmin_\mathbf{h}R_\text{log}(h_j(\mathbf{x})|\mathbf{x})
hj∗(x)=hargminRlog(hj(x)∣x)
当然,交叉熵损失函数实际上只对应一般的(文章中用了一个似乎比较地道的词:vanilla)1-vs-all方法。
而更加流行的评价指标就有
P
@
k
,
n
D
C
G
@
k
,
P
S
P
@
k
P@k,nDCG@k,PSP@k
P@k,nDCG@k,PSP@k等,也就是人们通常只关心top-k。
PLT model
什么是标签树?如果标签结构有层次关系,那么标签树可以自然导出(不一定是如下图所示的二叉结构)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0jUaOGeR-1663573740332)(images/006.png)]
文章定义了一个
z
\mathbf{z}
z向量。对任意一个
y
∈
Y
\mathbf{y} \in \mathcal{Y}
y∈Y,
z
=
{
z
v
1
,
z
v
2
,
…
}
\mathbf{z} = \{z_{v_1}, z_{v_2}, \dots\}
z={zv1,zv2,…},其中的分量对应标签树中的节点,比如对节点
v
v
v:
z
v
=
1
if
∑
j
∈
L
v
y
j
≥
1
,
0
otherwise
.
z_{v} = 1 \text{ if } \sum_{j \in \mathcal{L}_v} y_j \geq 1, 0 \text{ otherwise}.
zv=1 if j∈Lv∑yj≥1,0 otherwise.
其中
L
v
\mathcal{L}_v
Lv为节点
v
v
v下的所有叶子节点标签集合。
根据链式法则,有
P
(
z
v
=
1
∣
x
)
=
∏
v
′
∈
Path
(
v
)
P
(
z
v
′
=
1
∣
z
↑
(
v
′
)
=
1
,
x
)
=
P
(
z
v
=
1
∣
z
↑
(
v
)
=
1
,
x
)
P
(
z
↑
(
v
)
=
1
∣
x
)
P(z_v = 1 | \mathbf{x}) = \prod_{v' \in \text{Path}(v)} P(z_{v'} = 1 | z_{\uparrow(v')} = 1, \mathbf{x}) = P(z_v = 1 | z_{\uparrow(v)}=1, \mathbf{x}) P(z_{\uparrow(v)} = 1 | \mathbf{x})
P(zv=1∣x)=v′∈Path(v)∏P(zv′=1∣z↑(v′)=1,x)=P(zv=1∣z↑(v)=1,x)P(z↑(v)=1∣x)
作者提出了一个Proposition: 对任意一个
T
,
P
(
y
∣
x
)
T, \mathbf{P}(\mathbf{y}|\mathbf{x})
T,P(y∣x),以及一个内节点
v
∈
V
T
∖
L
T
v\in V_T \setminus L_T
v∈VT∖LT,都有下式成立
∑
v
′
∈
↓
(
v
)
P
(
z
v
′
=
1
∣
z
v
=
1
,
x
)
≥
1
\sum_{v'\in \downarrow(v)} P(z_{v'}=1 | z_{v}=1, \mathbf{x}) \geq 1
v′∈↓(v)∑P(zv′=1∣zv=1,x)≥1
这个Proposition说明了:如果对于一个节点
v
v
v,它存在至少一个叶子上的标签为1,那么其子节点条件概率之和应该要大于等于1。 之所以是大于等于而不是等于,考虑是multi-label场景不存在mutual exclusive;如果在multi-class场景下,那么上式应该只能是等于符号。
且
P
(
z
v
=
1
∣
x
)
P(z_v = 1 | \mathbf{x})
P(zv=1∣x)满足
max
{
P
(
z
v
′
=
1
∣
x
)
:
v
′
∈
↓
(
v
)
}
≤
P
(
z
v
=
1
∣
x
)
≤
min
{
1
,
∑
v
′
∈
↓
(
v
)
P
(
z
v
′
=
1
∣
x
)
}
\max\{P(z_{v'}=1|\mathbf{x}) : v' \in \downarrow(v)\} \leq P(z_v = 1 | \mathbf{x}) \leq \min \{1, \sum_{v'\in \downarrow(v)} P(z_{v'}=1 | \mathbf{x})\}
max{P(zv′=1∣x):v′∈↓(v)}≤P(zv=1∣x)≤min{1,v′∈↓(v)∑P(zv′=1∣x)}
上式很好理解:任意一个节点的无条件概率应该总是大于其子节点的无条件概率(这被称为Hierarchical constraint,在层次多标签分类问题里面是一种常见的约束),且应该总是小于其子节点概率之和。
训练阶段:为所有树中的节点训练分类器。
这里Algorithm2表示,将一个样本赋给树中的每一个节点Negative/Positive。
意思就是,如果一个样本在某一个标签上的值为1,那么该样本在对应的节点和祖先节点上都是positive的,否则为negative。
Algorithm1就是将每一个样本按照positive/negative的方式赋给每一个节点,这样树中的每一个节点都有了一些正例和负例样本,进而为每一个节点训练binary分类器。
所以什么是PLT?就是一棵树,其中每个结点对应一个binary分类器用以预测一个样本在当前结点的概率。
算法部分看起来相当简单,但非常的novel!PLT将多标签分类问题转化成了在树结构结点上的若干二分类问题,在预测阶段就可以剪枝了!
预测阶段:很显然,要能搜索这颗标签树(每个结点的分类器得到测试样本的估计概率),以得到预测结果。
文中采用了两种策略,一种是根据给定的概率阈值决定是否剪枝(算法3第5行),树的搜索过程采用栈结构就可以了。
另外一种策略是寻找概率估计最大的top-
k
k
k标签(叶子结点),采用了优先队列这种数据结构,如算法4所示。
注意算法4中的
Q
\mathcal{Q}
Q是一个优先队列,每次会弹出概率估计最大的结点
v
v
v。
这两种预测策略看起来都很常规,but非常巧妙!充分利用了树状结构的特点。
虽然算法看起来很简单,但作者做出了充分的统计分析(没细看,看估计也看不懂),强行将文章拉了几个理论高度,相当厉害!