当Top-k遇到深度学习

本文探讨了如何将不可微的Top-k操作融入深度学习框架。通过将Top-k问题转化为最优运输问题并添加正则熵,提出了一个名为SOFT top-k的平滑近似运算符,使得模型可以端到端训练。实验表明,这种方法在图像分类和自然语言生成任务中提高了性能。
摘要由CSDN通过智能技术生成

点击蓝字

关注我们

AI TIME欢迎每一位AI爱好者的加入!

top-k操作(即从分数集合中找到k个最大或最小元素)是一个重要的机器学习模型组件,被广泛用于信息检索和数据挖掘中。但是,如果top-k操作是通过算法方式(例如使用冒泡算法)计算的,则无法使用现在流行的梯度下降算法以端到端的方式训练所得模型。这是因为这些计算方式通常涉及交换索引,无法用来计算其梯度。换句话说,从输入数据到该元素是否属于前k个集合的指标向量的对应映射是不连续的。

为了解决这个问题,我们提出了一个平滑的近似操作,即SOFT top-k运算符。具体来说,我们的SOFT top-k运算符将top-k运算的输出近似为最优传输问题的解。然后,我们基于最优传输问题的KKT条件快速地估算SOFT运算符的梯度。我们将提出的算子应用于k nearest neighbors分类和beam search算法,并通过实验展示了性能的提高。

谢雨佳:本科毕业于中国科学技术大学少年班学院,现为佐治亚理工学院CSE系第五年博士生,导师为查宏远教授和赵拓教授。她的研究方向主要为最优传输理论和端到端学习。

一、动机:如何将Top K 嵌入到深度学习框架中?

k nearest neighbors (kNN) classifier 是一个非常常见且实用的分类方法。具体来说,假设有很多已知label的template data,以及一个未知label的query data,我们可以将未知的query data与其他image相比,得到比较相似的k个图片,将其称为k nearest neighbors,并将这些 neighbors的label作为该未知data的label。对于image data,一个自然的做法是使用特征抽取网络(feature extraction)将这些image投影到一个embedding space里,并在这个embedding space中做kNN以得到对未知data的预测。然后我们最小化损失函数来更新特征抽取网络中的参数。但是,该模型框架的缺点在于:由于需要进行的top-k操作不可微,因此不能通过梯度下降法或者随机梯度下降法实现参数优化。

图1  深度k近邻网络结构示意图

为什么top-k 不可微?考虑一个top-k算

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值