谷歌采样修正的双塔模型

贡献

本文提出了一种从流式数据中估计item频率的新算法,通过理论推导,证明了该算法可以在无需固定item词表的情况下生效,并且能够产生无偏估计,同时能够适应item分布的变化。以解决热门商品在负样本采样时,采样次数过多而被过度惩罚。

业内的主流方法和问题

推荐领域中emb学习的挑战通常有两个:1)对于许多工业级别的应用来说item语料规模会相当大。2)采集自用户反馈的训练数据对许多item来说非常稀疏,从而导致模型预测的长尾内容有很大的方差。面对商品冷启动问题,现实世界的系统需要适应数据分布的变化,以更好地获取新鲜item。

双塔网络算法原理

双塔网络与NCF(神经协同过滤)不同,双塔网络降低耗时且修正了模型损失函数。NCF简介
利用双塔模型构架推荐系统,首先建立两个参数embedding函数,把query和候选item映射到k维向量空间,模型的输出为二者的embedding内积。
模型结构如图所示:
YouTube双塔网络

In-batch loss function

推荐问题可以看作是,给定query X,从M个item中得到y的概率可以利用softmax函数计算:在这里插入图片描述
考虑反馈 ri, 加权对数似然损失函数为:
在这里插入图片描述
当M非常大(样本总数很大)时,我们通常可以利用负采样算法进行计算。然而对于流数据,我们考虑在同一个batch中采样负样本。此处可看成是现实场景下,训练数据是分批到达,模型训练也是分batch进行。
batch-softmax函数为:
在这里插入图片描述
在每个batch中,由于存在幂律分布现象。如果在每个batch中随机采样负样本,会使热门商品更容易被采样到,在损失函数中就“过度”惩罚了这些热门商品,因此考虑用频率对采样进行修正,即:
在这里插入图片描述
其中 Pj 是在每个batch中随机采样到item j的概率(将在下一节中介绍),因此修正后的条件概率函数为:在这里插入图片描述

Streaming Frequency Estimation

此方法用于估计在流数据中,每个batch下item出现的概率。上面提到的Pj。

对于一个流式的随机batch,问题是预估每个batch中每个item 的出现概率。一个关键的设计准则是当有多个训练jobs(workers)时,要有一个全局的预估来支持分布式训练。此处可以利用全局step,并对一个item的频率预估转化为deta预估,其表示为两次连续命中item所需的平均step。例如,如果一个item每50step采样一次,deta = 50,则得到p = 0.02。使用全局step提供了两点优势:1)通过读取和修改全局step,多个worker在频率预估中隐式的同步。2)预测通过简单的滑动平均来更新,该更新适用于分布的改变。

为了解决hash collision的问题,可以建立多个数组 Ai Bi 最终在多个数组中取最大。

定义两个大小为H的数组A,B,哈希函数h可以把每个item映射为[H]内的整数。

A[h(y)]表示item y上次被采样到的时刻
B[h(y)]表示每多少步item y可以被采样一次
先说结论,当第t步y被采样到时,利用迭代可更新A,B:
在这里插入图片描述
alpha 可看作学习率。通过上式更新后,则在每个batch中item y出现的概率为 1/B[h(y)]。
直观上,上式可以看作利用SGD算法和固定的学习率 [公式] 来学习“可以多久被采样到一次”这个随机变量的均值。

下面,可以从数学理论上证明这种迭代更新的有效性:
在这里插入图片描述在这里插入图片描述

算法总结

涵盖了In-batch loss function 和 流数据频率估计 的训练算法
训练算法
流数据频率估计 算法
流数据频率估计
改进的多元数组-频率估计算法
为了解决hash collision的问题,可以建立多个数组 Ai Bi 最终在多个数组中取最大。
在这里插入图片描述

归一化&&微调

在这里插入图片描述

参考:https://www.jianshu.com/p/177f49effd50

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值