Sampled Softmax


sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation
以及tensorflow关于candidate sampling的文档:candidate sampling


1. 问题背景

在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:

其中的 a a a为一个单层的前馈神经网络,根据 α t \alpha_t αt和输入的因状态,我们就可以得到一个context vector c t c_t ct。在decoder的t时刻,输出的目标词汇的概率可以使用如下公式计算:

其中, y t − 1 y_{t-1} yt1是上一个次的输出, z t z_t zt为当前decoder的隐状态, c t c_t ct为context vector。
因为我们输出的是一个概率值,所以(6)式的归一化银子 Z Z Z的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。
基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)

2. 解决方法

上面已经说过,计算归一化的因子 Z Z Z,因为所用的词太多造成复杂度的上升,那么原文的方法就是使用一个子集 V ′ V' V来近似的计算出 Z Z Z, 假设我们现在已经知道的这个子集,那么之前计算输出的概率公式就为:

好了,那么 V ′ V' V怎么取?

我们看看tensorflow中的文档吧: https://www.tensorflow.org/extras/candidate_sampling.pdf
对于Sampled Softmax的每一个训练样例 ( x i , { t i } ) (x_i,\{t_i\}) (xi,{ti}),我们根据采样函数 Q ( y ∣ x ) Q(y|x) Q(yx),从所有的输出集合中挑选一个小的子集 S i S_{i} Si。要求选择子集的函数和具体的训练样本无关。假设full softmax的输出全集为 L L L, 那么在给定 x i x_i xi的情况下,根据分布 Q Q Q L L L中抽取的子集似然函数为:

然后我们生成一个包含 S i S_i Si和训练目标类的候选集合 V V V
V ′ = S i ∪ t i V'=S_i \cup{t_i} V=Siti
之后我们的训练目标就是找出样本为 V ′ V' V的哪一个类别了。
(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)

3. tensorflow的实现

def sampled_softmax_loss(weights,
                         biases,
                         labels,
                         inputs,
                         num_sampled, # 每一个batch随机选择的类别
                         num_classes, # 所有可能的类别
                         num_true=1, #每一个sample的类别数量
                         sampled_values=None,
                         remove_accidental_hits=True,
                         partition_strategy="mod",
                         name="sampled_softmax_loss"):

tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。

原文:
This operation is for training only. It is generally an underestimate of
the full softmax loss.
A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.

这个函数的主体主要调用了另外一个函数:

logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
 

上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: S i ∪ t i S_i \cup{t_i} Siti
而这个函数采样集合的代码如下:

sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,# 真实的label
          num_true=num_true,
          num_sampled=num_sampled, # 需要采样的子集大小
          unique=True,
          range_max=num_classes)

而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution
即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:

Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值