Mutual Wasserstein Discrepancy Minimization for Sequential Recommendation
论文连接:Mutual Wasserstein Discrepancy Minimization for Sequential Recommendation
1 动机
基于自监督的序列推荐通过设计的数据增强方法以及最大化互信息,显著低提升推荐系统的性能。然而基于KL散度的互信息估计方法存在一些限制。同时现在有数据增强方法大多都是随机的,这种方法可能会破坏序列的相关性。
针对上述两个问题,在本工作中作者提出了一个基于Mutual Wasserstein discrepancy minimization(Mstein)的用于序列推荐的自监督学习框架。具体来说使用Wasserstein Discrepancy Measurement 来衡量增强序列之间的互信息。作者在四个数据集上的实验结果表明, Mstein对数据扰动具有鲁棒性的同时提升了序列推荐性能。
因此本文主要contribution归纳如下:
- 提出Wasserstein Discrepancy Measurement替换KL散度作为互信息的对比下界,以缓解使用KL散度的InfoNCE对比损失的局限性。
- 提出了mutual Wasserstein discrepancy minimization的对比损失函数,并验证了其在建模增强的随机性和对扰动的稳健表示学习方面的优越性。
- 论证了在Wasserstein Discrepancy Measurement,在对齐性和均匀性得到了精确到优化。
- 在四个数据集上进行相关的实验。
2 方法
本文所提出的Mstein框架是基于S^3Rec和STOSA模型框架上实现的。因此作者首先介绍了STOSA推荐模型,模型结构如下图所示:
2.1 Stochastic Transformer for SR
为了对序列推荐中不确定的信息进行建模,STOSA在Transformer的基础上引入随机嵌入和Wassertein self-attention模块。在STOSA中将物品表示为高斯分布,其随机嵌入(stochastic embeddings)表述如下:
E
μ
=
E
m
b
μ
(
S
u
)
,
E
∑
=
E
m
b
∑
(
S
u
)
\mathbf{E}^\mu = {\rm Emb}_{\mu}(S^u),\mathbf{E^{\sum}}={\rm Emb}_{\sum}(S^u)
Eμ=Embμ(Su),E∑=Emb∑(Su)
对于行为序列
S
u
\mathcal S^u
Su中的物品
v
i
v_i
vi,STOSA方法使用了高斯分布来建模物品的随机嵌入,其中每个物品的embedding由均值
E
μ
\mathbf{E}^{\mu}
Eμ和方差
d
i
a
g
(
E
K
∑
)
{\rm diag}(\mathbf{E}^{\sum}_K)
diag(EK∑)来参数化。Wasserstein自注意力采用了负的2-Wasserstein距离,来计算特定一对物品(𝑣𝑖 , 𝑣𝑗 )之间的自注意力值,如下所示:
A
i
j
=
−
(
W
2
(
v
i
,
v
j
)
)
=
−
(
∣
∣
μ
v
i
−
μ
v
j
∣
∣
2
2
+
∣
∣
∑
v
i
1
/
2
−
∑
v
j
1
/
2
∣
∣
F
2
)
{\rm A}_{ij}=-(W_2(v_i,v_j))=-\left(||\mu_{v_i} - \mu_{v_j}||^2_2+||\sum^{1/2}_{v_i}-\sum^{1/2}_{v_j}||^2_F\right)
Aij=−(W2(vi,vj))=−
∣∣μvi−μvj∣∣22+∣∣vi∑1/2−vj∑1/2∣∣F2
其中
W
2
(
⋅
)
W_2(\cdot)
W2(⋅)表示2Wasserstein距离,
μ
u
i
=
E
v
i
μ
W
K
μ
\mu_{ui}={\rm E}^{\mu}_{v_i}{\rm W}^{\mu}_K
μui=EviμWKμ,
∑
v
i
=
E
L
U
(
d
i
a
g
(
E
^
u
i
∑
W
K
∑
)
)
+
1
\sum_{vi}={\rm ELU\left({\rm diag}(\hat{E}^{\sum}_{ui}W^{\sum}_K)\right)}+1
∑vi=ELU(diag(E^ui∑WK∑))+1。STOSA模型与transformer类似由前馈神经网络、残差连接以及layer normalization模块,因此STOAS的序列编码表示如下:
h
u
=
(
h
u
μ
,
h
u
∑
)
=
S
t
o
s
a
E
n
c
(
S
u
)
\mathbf{h_u}=(\mathbf{h}^{\mu}_u,\mathbf{h}^{\sum}_u)={\rm StosaEnc}(\mathcal S^u)
hu=(huμ,hu∑)=StosaEnc(Su)
其中
h
u
μ
\mathbf{h}^{\mu}_u
huμ和
h
u
∑
\mathbf{h}^{\sum}_u
hu∑表示
S
u
\mathcal S^u
Su的随机序列嵌入。对于时间步
t
t
t,
h
u
,
t
=
(
h
u
,
t
μ
h
u
,
t
∑
)
\mathbf{h}_{u,t}=(\mathbf{h}^{\mu}_{u,t}\mathbf{h}^{\sum}_{u,t})
hu,t=(hu,tμhu,t∑)编码next-item的表征。STOSA的损失函数表示如下:
L
r
e
c
=
∑
S
u
∈
S
∑
t
=
1
∣
S
u
∣
−
l
o
g
(
σ
(
W
2
(
h
u
,
t
,
v
t
−
)
−
W
2
(
h
u
,
t
,
v
t
+
)
)
)
+
λ
l
p
u
n
\mathcal L_{rec}=\sum_{\mathcal S^u \in \mathcal S}\sum_{t=1}^{|\mathcal S^u|}-{\rm log}(\sigma(W_2({\rm h}_{u,t},v^-_t)-W_2({\rm h}_{u,t},v^+_t)))+\lambda \mathcal l_{pun}
Lrec=Su∈S∑t=1∑∣Su∣−log(σ(W2(hu,t,vt−)−W2(hu,t,vt+)))+λlpun
其中
v
t
+
v^+_t
vt+是下一个物品的随机嵌入的ground truth,
v
t
−
v^-_t
vt−表示负采样物品的嵌入,
l
p
u
n
l_{pun}
lpun是STOSA中提出的损失函数。
2.2 对比损失里的InfoNCE
给定一批包含𝑁个用户序列的数据,随机数据增强会生成每个序列的两个扰动视图,从而得出在InfoNCE计算中有2
N
N
N个序列、
N
N
N个正样本对,以及
4
N
2
−
2
N
4N^2-2N
4N2−2N个负样本对。(
4
N
2
−
2
N
4N^2-2N
4N2−2N个负样本对是怎么算出来的?)对于包含
N
N
N个用户序列的批次
B
\mathcal B
B,增强后的样本对集合
S
B
S_{\mathcal B}
SB 如下所示:
S
B
=
{
S
a
u
1
,
S
b
u
1
,
…
,
S
a
u
N
,
S
b
u
N
}
S_{\mathcal B}=\{{\mathcal S^{u_1}_a,\mathcal S^{u_1}_b},\ldots,{\mathcal S^{u_N}_a,\mathcal S^{u_N}_b} \}
SB={Sau1,Sbu1,…,SauN,SbuN}
其中下标a和b表示
S
u
\mathcal S^u
Su的两种扰动视图。给定一个增强的序列对
(
S
a
u
i
,
S
b
u
i
)
(\mathcal S^{ui}_a,\mathcal S^{ui}_b)
(Saui,Sbui),InfoNCE计算如下:
L
c
l
(
h
a
u
i
,
h
b
u
i
)
=
−
l
o
g
e
x
p
(
s
i
m
(
h
a
u
i
,
h
b
u
i
)
)
e
x
p
(
s
i
m
(
h
a
u
i
,
h
b
u
i
)
)
+
∑
j
∈
S
B
−
e
x
p
(
s
i
m
(
h
a
u
i
,
h
j
)
)
\mathcal L_{cl}(\mathbf{h}^{ui}_a,\mathbf{h}^{ui}_b) = -{\rm log}\frac{{\rm exp}({\rm sim}(\mathbf{h}^{ui}_a,\mathbf{h}^{ui}_b))}{{\rm exp}({\rm sim}(\mathbf{h}^{ui}_a,\mathbf{h}^{ui}_b)) + \sum_{j \in S^-_{\mathcal B}}{\rm exp}({\rm sim}(\mathbf h^{ui}_a,\mathbf h^j))}
Lcl(haui,hbui)=−logexp(sim(haui,hbui))+∑j∈SB−exp(sim(haui,hj))exp(sim(haui,hbui))
其中
h
a
u
i
\mathbf h^{ui}_a
haui和
h
b
u
i
\mathbf h^{ui}_b
hbui是从编码器中学到的两种扰动视图的embedding。
S
B
−
=
S
B
−
{
S
a
u
i
,
S
b
u
i
}
S^-_{\mathcal B}=S_ \mathcal {B^{-}}\{\mathcal S^{ui}_a,\mathcal S^{ui}_b\}
SB−=SB−{Saui,Sbui}表示负样本的增强序列对。(正样本和负样本都需要增强操作?)
3 Wasserstein Discrepancy Measurement
3.1 InfoNCE and Mutual Information
对于给定用户 u i ui ui随机增强序列 ( x u i a = S u i a , x b u i = S b u i ) (x^ui_a=\mathcal S^ui_a,x^{ui}_b=\mathcal S^{ui}_b) (xuia=Suia,xbui=Sbui), ( x a u i , x b u i ) (x^{ui}_a,x^{ui}_b) (xaui,xbui)是服从随机增强分布的随机变量。InfoNCE和 ( x a u i , x b u i ) (x^{ui}_a,x^{ui}_b) (xaui,xbui)的互信息可以被描述为:
I
(
x
,
y
)
I(x,y)
I(x,y)表示随机变量
x
x
x和
y
y
y之间的互信息。上述公式(7)说明了优化
L
c
l
\mathcal L_{cl}
Lcl的同时最大化了互信息
I
(
x
a
u
i
,
x
b
u
i
)
>
l
o
g
(
2
N
−
1
)
−
L
c
l
I(x^{ui}_a,x^{ui}_b)>{\rm log}(2N-1)-\mathcal L_{cl}
I(xaui,xbui)>log(2N−1)−Lcl,公式(7)同时表明当batch size N越大,互信息可以更好的被优化。
3.2 Limitation of KL Divergence
KL散度作为互信息的对比下界存在的问题包括:对样本量的指数级需求以及训练的不稳定。现有的研究表明,基于KL散度的互信息估计在具有 N N N个样本上的高置信度下界,这个这个下界不超过 O ( l n ( N ) ) O(ln(N)) O(ln(N))。而在对比学习InfoNCE上同样有这样的限制。
给定两个增强序列分布:
p
(
x
a
)
p(x_a)
p(xa)和
p
(
x
b
)
p(x_b)
p(xb),
A
,
B
A,B
A,B表示以N从分布
p
(
x
a
)
、
p
(
x
b
)
p(x_a)、p(x_b)
p(xa)、p(xb)采样得到的增强序列集和。给定一个real-value的映射函数,以及置信参数
δ
\delta
δ,在概率
1
−
δ
1-\delta
1−δ的情况下可得:
D
K
L
(
P
(
x
a
)
,
p
(
x
b
)
)
≥
F
(
A
,
B
,
δ
)
D_{KL}(P(x_a),p(x_b))\ge F(A,B,\delta)
DKL(P(xa),p(xb))≥F(A,B,δ)
至少有概率
1
−
4
δ
1-4\delta
1−4δ:
l
n
N
≥
F
(
A
,
B
,
δ
)
{\rm ln}N \ge F(A,B,\delta)
lnN≥F(A,B,δ)
因此可以得到互信息的边界是
N
=
e
x
p
(
I
(
x
a
,
x
b
)
)
N={\rm exp}(I(x_a,x_b))
N=exp(I(xa,xb))
4 优化和预测
L
=
L
r
e
c
+
β
L
M
S
t
e
i
n
\mathcal L = \mathcal L_{rec}+\beta \mathcal L_{\rm MStein}
L=Lrec+βLMStein
其中,
β
\beta
β是用于调整对比损失和互信息Wasserstein差异最小化贡献的超参数。最终的推荐列表是通过计算序列编码分布嵌入
(
h
u
μ
,
h
u
∑
)
(\mathbf h^{\mu}_u,\mathbf h^{\sum}_u)
(huμ,hu∑)与所有物品的随机嵌入之间的Wasserstein距离来生成的。对所有物品的距离按升序排序,以生成前N个物品。
5 实验结果(部分)
在文章中作者进行了不同的实验,验证了模型的性能。
6 总结
作者研究了互信息和InfoNCE之间的关系,并讨论了基于KL散度的互信息估计的局限性,包括不对称估计、对样本量的指数级需求和训练不稳定性。我们提出了一种基于Wasserstein距离的替代互信息估计选择,称为Wasserstein Discrepancy Measurement。通过提出的Wasserstein Discrepancy Measurement,我们在InfoNCE框架中制定了互Wasserstein差异最小化的方法,称为MStein。在四个基准数据集上进行的大量实验表明,使用Wasserstein Discrepancy Measurement的MStein在互信息估计中具有更高的性能。额外的稳健性分析证明了MStein对于嘈杂的交互和不同数据大小的变化更加稳健。