![0fa5332597d73e47a1cb3b04c661d16b.png](https://i-blog.csdnimg.cn/blog_migrate/f42a8a0bd759fe752c41eace4d7b26d9.jpeg)
@[toc]
前言
夜小白:基于表征(Representation)文本匹配、信息检索、向量召回的方法总结(用于召回、或者粗排)zhuanlan.zhihu.com![05b5a6b5a6d0b521fd60211ee3bcd869.png](https://i-blog.csdnimg.cn/blog_migrate/b13f9eaf9aef70c14ea2bd9dcb0bf2d2.jpeg)
![00e9c98840bb393886a707a8ce3be307.png](https://i-blog.csdnimg.cn/blog_migrate/60f79a668c24e5b75f77891441488418.jpeg)
前面两篇关于文本匹配的博客中,都用到了Sampled-softmax训练方法来加速训练,Sampled-softmax简单点来说,就是通过采样,来减少我们训练计算loss时输出层的运算量。从第一篇博客中的不知其然,到后面看到DSSM代码中Sampled softamax的知其然,这篇博客目的是在知其所以然,从Sampled softmax的数学原理思考,为什么DSSM中的训练代码可以这样写,代码还能怎么改进。
这段时间也一直在思考,如何才能不随波逐流,如何才能成为一名独当一面的算法工程师,我想对于一个问题的浅尝辄止肯定是远远不够的,不仅要知其然还要知其所以然,光是读懂这几篇论文是不够的,进一步的要理解代码工程实现,更进一步,去理解代码背后的数学原理,为什么代码这样做一定能保证结果正确或者收敛,了解了这些,我们才能够根据自己的想法去做优化,我想对于现在日益成熟的深度学习,难的可能不是如何实现,而是对于自己的实际场景去调整优化。
上面有点扯远了,回归正题,这篇博客主要基于Tensorflow官方对于Sampled softmax文档,建议大家有问题不懂的时候多看官方文档,写的非常通俗易懂,下面我就说说自己对Sampled Softmax数学原理的理解。
What is Candidate Sampling Tensorflow 官方文档
什么是Sampled Softmax
1、logits与softmax
当我们做分类问题时,假设我们需要分类的类别数为
- 神经网络最后一层输出层神经元个数为
,每个神经元输出分别表示
logits
, 这里的logits
其实代表的就是各个类别未经归一化的概率分布(也就是加起来不为1),网络就是学习出一个映射 - 将上述输出的
logits
作为softmax
的输入进行归一化操作,softmax
的输出则是表示各个类别上的概率分布 - 根据这个概率分布计算损失函数,如交叉熵损失
还是采用之前博客中的Query-Doc Softmax作为说明,从logtis
进行softmax
归一化公式如下:
![f88a740e6e931a6df01ed29c402bbb9e.png](https://i-blog.csdnimg.cn/blog_migrate/7a466decc6da0870bd8620498996a5b6.png)
-
表示我们的输入,
表示我们的模型,
即是给定
情况下,输出类别为
的
logits
- 我们注意分母中
即为所有文档集合,也就是我们的总类别数
这个公式的具体解释可以参考之前的两篇博客,下面分析一下上面这个公式,下面是重点:
- 当我们类别数非常大时,也就是
非常大时,那么我们分母的计算量就会非常大,因为需要在整个类别全集上求和。比如假设我们有100W个文档,那么如果我们不做任何处理,
softmax
归一化 - 我们如果对每个类别
logits
加上一个与类别无关的常数,结果将不会变化。这个很好理解,当我们对每个logits
均加上同一个常数K
,那么分子分母可以约去这个常数K
,结果不变 * - 分母其实是一个归一化因子,如果看过PRML同学应该熟悉,有点类似于指数族分布中的
partition function
,分母与类别无关,因为分母中对整个类别集合进行了求和,给定输入后,分母归一化因子也就确定了。
从上面分析可以知道,我们的关键词是logits
、softmax归一化
。logits
本质上就是未归一化的概率,softmax
目的就是计算归一化因子(分母),对logtis
进行归一化,从而得到一个概率分布。问题就在于需要对整个类别集合
logtis
并求和,当类别集合比较大时(比如上面的Query-Doc预测,以及语言模型训练),计算量会非常大。
2、Sampled Softmax
Sampled Softmax
的核心思想就在于 Sampled
,既然类别全集太大,那么能不能采样一个类别子集,然后在计算在子集上的logtis
然后进行softmax
归一化呢?假设我们类别全集为
我们在训练模型时,只要在这个采样出来的
logits
和
softmax
就可以了,大大减少了计算量,加快训练过程。现在问题是:
- *当我们进行采样之后,各个类别
logits
应该如何计算,和使用类别全集时的logtis
有什么对应关系?
Sampled Softmax背后的数学原理
从上面可以看出,当我们进行采样后,按理来说logtis
计算方法也需要改变,这样才能最后得到正确的概率分布。前方公式预警!!!!
1、数学符号约定
-
表示我们的一个训练样本,
为输入模型的特征,
为标签,目标类别
-
给定输入
,输出类别为
的条件概率
-
给定输入
,输出类别为
的
logtis
,这里其实表示的就是我们的模型
-
类别全集
-
采样函数,给定输入
,采样出类别
的概率
-
采样出来的类别子集
以上符号如果没有特殊说明,都表示是在类别全集上进行计算
2、logits与概率之间的关系
![e4255caad9fba731ad7c2496892a6d57.png](https://i-blog.csdnimg.cn/blog_migrate/2a050f01b599746757d1453219abc649.png)
其中
softmax
计算出来的分母。推导也很简单:
最后将
logits
可以写成“
3、采样出类别子集
![e682682d6307687a6fb51e6a43a8d71c.png](https://i-blog.csdnimg.cn/blog_migrate/5b60526eb8bce486921f6551be51002c.png)
这里推导也很简单,当
4、计算采样后类别子集
重点来了!前面都是铺垫,我们最终的目的是计算给定输入
logits
的正确表示形式啦~,我们假设
![dd5f30d6bc2c3d9a731a9de3a8599ccc.png](https://i-blog.csdnimg.cn/blog_migrate/b8063536c1676cc15a0471842eb41f44.png)
上面的推导就是简单的贝叶斯公式。我们分析一下推导结果:
-
这个就是在类别全集情况下,给定输入
,输出类别为
的条件概率
-
这个概率就是给定类别
,输入
情况下,采样出类别子集
的概率,这个计算方式已经在3中,
![fbe8baa81a704d12d7407ae6ce6a1c15.png](https://i-blog.csdnimg.cn/blog_migrate/bafdab013d189fd863b169c75e18ac17.png)
-
这其实是个和输出类别
无关的常量,可以视为const
综上,下面
![0a57bceff14eaeb6e42f6822ee9b1ddb.png](https://i-blog.csdnimg.cn/blog_migrate/57dc7c3d2bc2ffe4c6746ae903674bc0.png)
其中
![77c18839a76745fb656de8f3ebd9887c.png](https://i-blog.csdnimg.cn/blog_migrate/366d30decf745d32ed9f06aa0f3d6391.png)
结果已经跃然纸上,
5、采样后类别子集
logits
和原始
logits
关系
终于要到最后一步了,我们已经知道了采样后类别子集
logits
和原始
logits
关系,推导如下:
其中与类别
大功告成!上面的公式就是我们进行采样后的logtis
与原始logits
关系,具体的用法如下:
- 通过
对类别进行采样,得到一个类别子集
- 模型对采样类别子集
中的类别分别计算
logits
(这样就不用在类别全集计算logits
了),这里得到的其实是 - 对于计算出来的
,减去
,就得到了我们采样后子集的
logits
, - 使用
作为
softmax
输入,计算概率分布以及loss进行梯度下降
DSSM Sampled Softmax 分析
从上面分析可以得到:
我们选取不同的采样函数
tf.nn.log_uniform_candidate_sampler
,按照 log-uniform (Zipfian) 分布采样。
![56fdc63963109c959769d80ef8289c7a.png](https://i-blog.csdnimg.cn/blog_migrate/9b265db1b5480f3113d0f5e029acc8e5.png)
tf.nn.learned_unigram_candidate_sampler
按照训练数据中类别出现分布进行采样。具体实现方式:1)初始化一个 [0, range_max] 的数组, 数组元素初始为1; 2) 在训练过程中碰到一个类别,就将相应数组元素加 1;3) 每次按照数组归一化得到的概率进行采样。
上述采样方式都和输入
logits
加上或者减去一个常数,对
softmax
结果并没有影响,所以可以用
原始logits
代替采样后的logits
。所以DSSM代码中,构造子集后直接计算
logits
然后做
softmax
结果也是正确的,代码如下:
with tf.name_scope('Loss'):
# Train Loss
# 转化为softmax概率矩阵。
prob = tf.nn.softmax(cos_sim)
# 只取第一列,即正样本列概率。相当于one-hot标签为[1,0,0,0,.....,0]
hit_prob = tf.slice(prob, [0, 0], [-1, 1])
loss = -tf.reduce_sum(tf.log(hit_prob))
tf.summary.scalar('loss', loss)
总结
理论指导实践,代码中每一步都是有理论依据的,所以只有弄懂其背后的数学原理才能各个算法活学活用。以上也都是我的个人理解,难免有错,欢迎大家和我讨论,一起学习,一起进步~