论文《TAGNN:Target Attentive Graph Neural Networks for Session-based Recommendation》阅读

论文概况

今天给大家带来的论文是中国科学院谭铁牛老师及其团队成员在SR-GNN的基础上,发表在SIGIR 2020上的一篇短文,完成模型TAGNN。

Introduction

这篇论文在SR-GNN的基础上,沿用了门控图神经网络(Gated Graph Neural Networks,GGNN)模型,并加入了对预测目标敏感的embedding表示,下面进行介绍。

The Proposed Mothod: TAGNN

Learning Item Embedding

这里使用GGNN完成物品embedding在每个session graph上的传播,具体如下:

公式(1)基于 t − 1 t-1 t1 时刻(即上一跳,也即邻接节点)得到当前向量的latent embedding,通过 A s , i : \mathbf{A}_{s,i:} As,i: 完成邻接节点的聚合,这里数学表达可能稍有问题,需要将 A s , i : \mathbf{A}_{s,i:} As,i: 进行stack和sum操作,使之维度保持在 R d \mathbb{R}^{d} Rd 维度上。

a s , i ( t ) = A s , i : [ v 1 ( t − 1 ) , ⋯   , v s n ( t − 1 ) ] H + b (1) \mathbf{a}_{s,i}^{(t)}=\mathbf{A}_{s,i:}[v_1^{(t-1)}, \cdots, v_{s_n}^{(t-1)}] \mathbf{H}+\mathbf{b} \tag{1} as,i(t)=As,i:[v1(t1),,vsn(t1)]H+b(1)

公式(2)完成重置门(Reset Gate)的计算,得到 R d \mathbb{R}^{d} Rd 大小的列向量。

z s , i ( t ) = σ ( W z a s , i ( t ) + U z v i ( t − 1 ) ) (2) \mathbf{z}_{s,i}^{(t)} =\sigma( \mathbf{W}_z \mathbf{a}_{s,i}^{(t)} + \mathbf{U}_z \mathbf{v}_{i}^{(t-1)}) \tag{2} zs,i(t)=σ(Wzas,i(t)+Uzvi(t1))(2)

公式(3)完成更新门(Update Gate)的计算,得到 R d \mathbb{R}^{d} Rd 大小的列向量。

r s , i ( t ) = σ ( W r a s , i ( t ) + U r v i ( t − 1 ) ) (3) \mathbf{r}_{s,i}^{(t)} =\sigma( \mathbf{W}_r \mathbf{a}_{s,i}^{(t)} + \mathbf{U}_r \mathbf{v}_{i}^{(t-1)}) \tag{3} rs,i(t)=σ(Wras,i(t)+Urvi(t1))(3)

公式(4)完成重置向量的计算。

v i ( t ) ~ = t a n h ( W o a s , i ( t ) + U o ( r s , i ( t ) ) ⊙ v i ( t − 1 ) ) ) (4) \widetilde{ \mathbf{v}_{i}^{(t)} } = tanh( \mathbf{W}_o \mathbf{a}_{s,i}^{(t)} + \mathbf{U}_o ( \mathbf{r}_{s, i}^{(t)}) \odot \mathbf{v}_{i}^{(t-1)})) \tag{4} vi(t) =tanh(Woas,i(t)+Uo(rs,i(t))vi(t1)))(4)

公式(5)将 v i ( t − 1 ) \mathbf{v}_{i}^{(t-1)} vi(t1) 1 − z s , i ( t ) 1-\mathbf{z}_{s,i}^{(t)} 1zs,i(t)进行 element-wise 乘积运算, z s , i ( t ) \mathbf{z}_{s,i}^{(t)} zs,i(t) v i ( t ) ~ \widetilde{\mathbf{v}_{i}^{(t)}} vi(t) 进行 element-wise 乘积运算得到最终的 t t t 时刻节点 i i i 的embedding向量。

v i ( t ) = ( 1 − z s , i ( t ) ) ⊙ v i ( t − 1 ) + z s , i ( t ) ⊙ v i ( t ) ~ (5) \mathbf{v}_{i}^{(t)} = (1-\mathbf{z}_{s,i}^{(t)} ) \odot \mathbf{v}_{i}^{(t-1)} + \mathbf{z}_{s,i}^{(t)} \odot \widetilde{\mathbf{v}_{i}^{(t)}} \tag{5} vi(t)=(1zs,i(t))vi(t1)+zs,i(t)vi(t) (5)

Generating Session Embeddings

Session Local Embedding

local embedding使用最后一个节点embedding表示,即
s l o c a l = v s , s n s_{local} = v_{s, s_n} slocal=vs,sn

其中, s n s_n sn 表示session s s s 的长度。

Session Global Embedding

global embedding使用session内所有item的加权之和进行表示,即:
s g l o b a l = ∑ i = 1 s n α i v i (9) s_{global}=\sum_{i=1}^{s_n}{\alpha_i \mathbf{v}_i} \tag{9} sglobal=i=1snαivi(9)

这里, α i \alpha_i αi 表示注意力权重系数,具体如下:
α i = q T σ ( W 1 v s n + W 2 v i + c ) (8) \alpha_i = \mathbf{q}^{\mathsf{T}}\sigma(\mathbf{W}_1\mathbf{v}_{s_n} + \mathbf{W}_2\mathbf{v}_i+\mathbf{c}) \tag{8} αi=qTσ(W1vsn+W2vi+c)(8)

这里的 α i \alpha_i αi 通过当前物品 v i v_i vi 与session 最后一个物品 v s n v_{s_n} vsn 联合计算得出,用于表示当前物品针对用户最后一个点击物品的重要程度。 q \mathbf{q} q W 1 \mathbf{W}_1 W1 W 2 \mathbf{W}_2 W2 c \mathbf{c} c 都是可训练参数。

Session Target Embedding

这里是本文的亮点,实际上就是针对每一个目标候选 v s , s n + 1 v_{s,s_n+1} vs,sn+1 进行注意力系数的计算,从而得到不同的目标候选时session内的每个物品对目标的重要性,从而能够得到不同的session embedding表示,也就是 s t a r g e t t s_{target}^t stargett。具体计算如下:

β i , t = s o f t m a x ( exp ⁡ ( v t T W v i ) ∑ j = 1 m exp ⁡ ( v t T W v j ) ) (6) \beta_{i,t} = \mathsf{softmax}(\frac{ \exp( \mathbf{v}_t^\mathsf{T} \mathbf{W} \mathbf{v}_i)}{ \sum_{j=1}^{m}{\exp( \mathbf{v}_t^\mathsf{T} \mathbf{W} \mathbf{v}_j)} }) \tag{6} βi,t=softmax(j=1mexp(vtTWvj)exp(vtTWvi))(6)

这里, β i , t \beta_{i,t} βi,t 表示物品 i i i 对目标物品 t t t 的注意力系数大小。其中 m m m 表示所有session中的物品集合大小,即 m = ∣ V ∣ m=|V| m=V V V V 表示item集合。

s t a r g e t t = ∑ i = 1 s n β i , t ⋅ v i (7) \mathbf{s}_{target}^t = \sum_{i=1}^{s_n}{\beta_{i,t}\cdot \mathbf{v_i}} \tag{7} stargett=i=1snβi,tvi(7)

这里可以看到,针对一个目标物体 t t t,产生一个 s t a r g e t t s_{target}^t stargett ,其中, t ∈ [ 1 , 2 , ⋯   , m ] t \in [1, 2, \cdots, m] t[1,2,,m]

Session Embedding

s t = W 3 [ s t a r g e t t ; s l o c a l ; s g l o b a l ] (10) \mathbf{s}_t = \mathbf{W}_3 [s_{target}^t;s_{local};s_{global}] \tag{10} st=W3[stargett;slocal;sglobal](10)

由此,我们可以看到模型的弊端:
模型无法胜任大型数据集,随着数据集中节点个数的增加, s t a r g e t t s_{target}^t stargett s t \mathbf{s}_t st 线性增加,从而占据较大空间

但是具体占据空间大小还有待验证,欢迎评论区进行讨论。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值