2020.1.14 Updates
原来 tensorflow 自带一个 triplet loss 实现,支持单标签数据,见 [3]。可以参考写法。
Triplet Loss
triplet loss 的形式:
L
=
max
{
0
,
d
(
x
a
,
x
p
)
+
α
−
d
(
x
a
,
x
n
)
}
L=\max\{0, d(x_a,x_p)+\alpha-d(x_a,x_n)\}
L=max{0,d(xa,xp)+α−d(xa,xn)}
其中
d
(
⋅
,
⋅
)
d(\cdot,\cdot)
d(⋅,⋅) 表示一种距离,如欧氏距离。
涉及三个样本,其中 anchor 样本
x
a
x_a
xa 与 positive 样本
x
p
x_p
xp 相似(有共同 label),与 negative 样本
x
n
x_n
xn 不相似(无共同 label);而且要求它们是三个不同的样本。
目标效果是使每个样本(的 embedding)与不相似样本的距离至少比与相似样本的距离大一个 margin
α
\alpha
α,即
d
(
x
a
,
x
n
)
≥
d
(
x
a
,
x
p
)
+
α
d(x_a,x_n)\geq d(x_a,x_p)+\alpha
d(xa,xn)≥d(xa,xp)+α。
Sampling
计算此损失需要采样若个
(
x
a
,
x
p
,
x
n
)
(x_a,x_p,x_n)
(xa,xp,xn)。
对于某个 anchor,positive 和 negative sample 的采样可以 off-line 也可以 on-line。
这里是在一个 batch 里 on-line 地采样,可以用到样本最新的 embedding。
Mask & Distance
大思路是用 mask:把 batch 内所有 n3 个距离算出来,然后用 mask 筛出需要的部分。
需要用到两种 mask:
- 关于 index 的,因为 ( x a , x p , x n ) (x_a,x_p,x_n) (xa,xp,xn) 中三个样本要是不同的样本;
- 关于 label 的,因为 anchor 要和 positive 相似、和 negative 不相似;
然后两个 mask 取交集。
下面假设 label
是传入的一个 batch 的 label 的 one-/multi-hot 向量,是 tf 的 tensor。
index mask
求一个三阶张量 M,使得
M
i
,
j
,
k
=
1
M_{i,j,k}=1
Mi,j,k=1 当且仅当
i
,
j
,
k
i,j,k
i,j,k 各不相等。
单位阵
I
I
I 表示 index 相同的集合,即
I
[
i
]
[
j
]
=
1
I[i][j] = 1
I[i][j]=1 当且仅当
i
=
j
i=j
i=j;而
I
ˉ
=
1
−
I
\bar I=1-I
Iˉ=1−I 就相反。
I
ˉ
\bar I
Iˉ 在沿 axis = 0 处升维,得到 A,
A
[
⋅
]
[
i
]
[
j
]
=
1
A[\cdot][i][j]=1
A[⋅][i][j]=1 当且仅当
i
≠
j
i\neq j
i=j(即第 2、3 维下标不等);类似地沿 axis = 1,2 处升维得到
B
[
i
]
[
⋅
]
[
j
]
B[i][\cdot][j]
B[i][⋅][j]、
C
[
i
]
[
j
]
[
⋅
]
C[i][j][\cdot]
C[i][j][⋅],三者取交就得到 M。
# import tensorflow as tf
def index_mask(label):
batch_size = tf.shape(label)[0]
I = tf.cast(tf.eye(batch_size), tf.bool) # 单位阵 I
I_bar = tf.logical_not(I) # 1 - I
A = tf.expand_dims(I_bar, 0) # 2, 3 维不等
B = tf.expand_dims(I_bar, 1) # 1, 3 维不等
C = tf.expand_dims(I_bar, 2) # 1, 2 维不等
M = tf.logical_and(tf.logical_and(A, B), C) # 三者同时成立
return M
similarity mask
求一个三阶张量 M,使得
M
i
,
j
,
k
=
1
M_{i,j,k}=1
Mi,j,k=1 当且仅当
x
i
x_i
xi 和
x
j
x_j
xj 相似,而和
x
k
x_k
xk 不相似。
先求个相似矩阵 S,
S
i
,
j
=
1
S_{i,j}=1
Si,j=1 当且仅当
x
i
x_i
xi 和
x
j
x_j
xj 相似。这可以由 label 矩阵算出来,单标签、多标签都行。
类似上面,
S
ˉ
=
1
−
S
\bar S=1-S
Sˉ=1−S,然后 S 沿 axis = 2 升维到
E
[
i
]
[
j
]
[
⋅
]
E[i][j][\cdot]
E[i][j][⋅],
S
ˉ
\bar S
Sˉ 沿 axis = 1 升维到
F
[
i
]
[
⋅
]
[
j
]
F[i][\cdot][j]
F[i][⋅][j],两者取交。
# import tensorflow as tf
def similarity_mask(label):
S = tf.matmul(label, tf.transpose(label)) > 0 # 相似矩阵 S
S_bar = tf.logical_not(S) # 1 - S
E = tf.expand_dims(S, 2) # 1, 2 维相似
F = tf.expand_dims(S_bar, 1) # 1, 3 维不相似
M = tf.logical_and(E, F) # 两者者同时成立
return M
distance
求三阶张量 L,使得
L
i
,
j
,
k
=
max
{
0
,
d
(
i
,
j
)
−
d
(
i
,
k
)
+
α
}
L_{i,j,k}=\max\{0,d(i,j)-d(i,k)+\alpha\}
Li,j,k=max{0,d(i,j)−d(i,k)+α},即加 mask 前的 triplet loss。
距离自选。算出样本两两之间的距离矩阵 D 之后,沿 axis = 1,2 升维得到
N
[
i
]
[
⋅
]
[
j
]
N[i][\cdot][j]
N[i][⋅][j] 和
P
[
i
]
[
j
]
[
⋅
]
P[i][j][\cdot]
P[i][j][⋅]。可以将 P 当成 anchor 和 positive 的距离、N 当成 anchor 和 negative 的距离,于是
L
=
m
a
x
{
0
,
P
−
N
+
α
}
L=max\{0,P-N+\alpha\}
L=max{0,P−N+α}。
Code
对应 [1] 中的 batch all 策略
# import tensorflow as tf
def index_mask(label):
batch_size = tf.shape(label)[0]
I = tf.cast(tf.eye(batch_size), tf.bool) # 单位阵 I
I_bar = tf.logical_not(I) # 1 - I
A = tf.expand_dims(I_bar, 0) # 2, 3 维不等
B = tf.expand_dims(I_bar, 1) # 1, 3 维不等
C = tf.expand_dims(I_bar, 2) # 1, 2 维不等
M = tf.logical_and(tf.logical_and(A, B), C) # 三者同时成立
return M
def similarity_mask(label):
S = tf.matmul(label, tf.transpose(label)) > 0 # 相似矩阵 S
S_bar = tf.logical_not(S) # 1 - S
E = tf.expand_dims(S, 2) # 1, 2 维相似
F = tf.expand_dims(S_bar, 1) # 1, 3 维不相似
M = tf.logical_and(E, F) # 两者者同时成立
return M
def distance(x):
"""某种距离"""
return
def triplet_loss(x, label, alpha):
D = distance(x)
P = tf.expand_dims(D, 2) # d(a,p)
N = tf.expand_dims(D, 1) # d(a,n)
L = tf.maximum(0.0, P - N + alpha)
Mi = index_mask(label)
Ms = similarity_mask(label)
M = tf.logical_and(Mi, Ms)
triplet_loss = tf.multiply(L, M) # 筛选
# 算平均
valid_triplets = tf.to_float(tf.greater(triplet_loss, 1e-16))
num_positive_triplets = tf.reduce_sum(valid_triplets)
mean_triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)
return mean_triplet_loss