噪声对比估计NCE (Noise-contrastive estimation)采样方法,提高训练速度,解决源码中正label个数必须相等问题

1 目的

在分类任务中,如果分类类别量级较大,几十万甚至上百万的量级,那么最后一层分类层计算将会十分耗时,为了降低模型计算复杂度,每次前向计算,采样部分label参数 w w w参与计算,反向计算梯度的时候也只更新部分参与计算的部分 w w w,而不需要每次更新全部的权重参数 w w w,这样可以大大提高模型的训练速度。

2 tensorflow源码解读

2.1函数的输入和输出

nn_impl.py这是NCE loss的tensorflow源代码,接下来我们对源代码进行一个梳理和讲解,首先我们来看下nce_loss在tensorflow源码中的实现,如下代码所示:

def nce_loss(weights,
             biases,
             labels,
             inputs,
             num_sampled,
             num_classes,
             num_true=1,
             sampled_values=None,
             remove_accidental_hits=False,
             partition_strategy="mod",
             name="nce_loss"):
     #计算采样的labels和对应的logits(wx+b)值
	logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
     # 交叉熵loss
	sampled_losses = sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits, name="sampled_losses")
     # 返回loss求和
	return _sum_rows(sampled_losses)

函数输入参数

  • weights: [num_classes, dim],最后一层分类层的权重参数 w w w
  • biases: [num_classes],最后一层分类层的偏移 b b b
  • labels: [batch_size, num_true],int64类型,每个batch的正label索引idex,要求每个样本的正label数量必须一致,值为num_true (这也是在实际应用中不灵活的部分,后面会有改进的方案)
  • inputs: [batch_size, dim],输入分类层特征向量
  • num_sampled: int类型,每个样本随机采样的负样本个数
  • num_classes: int类型,分类的label总数量
  • num_true: int类型,每个样本的正label数量(一个batch里的所有样本的正label必须一致)
  • sampled_values: 自定义采样的候选集,是个三元组 (采样的候选集,正label数量,采样的label数量),默认是None,采用log_uniform_candidate_sampler采用器
  • remove_accidental_hits: bool类型,是否去除采样到的label有在正label集合里的,设置为“True"则会用负采样loss而不是NCE。
  • partition_strategy: 两种模式“mod"和”div",默认”mod",详情可以参考tf.nn.embedding_lookup
  • name: operation的名称

函数返回值
一维的向量,长度大小为[batch_size],对应每个样本的loss值。

2.2 代码讲解

接下来我们对每个函数的实现做一个深入分析,由上可知,nce_loss函数下主要有三个函数组成,_compute_sampled_logits,sigmoid_cross_entropy_with_logits和_sum_rows。

_compute_sampled_logits函数


def _compute_sampled_logits(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            div_flag=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):

   #数据格式转
  if isinstance(weights, variables.PartitionedVariable):
    weights = list(weights)
  if not isinstance(weights, list):
    weights = [weights]
  # 数据格式转换,将label [batch_size, num_ture] 展开,得到一维的size
  with ops.name_scope(name, "compute_sampled_logits",
                      weights + [biases, inputs, labels]):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])
   #如果采样label不传入,则默认用log_unifrom_candidate_sampler采样器,生成采用的label
    if sampled_values is None:
        sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes,
          seed=seed)
    # sampled:[num_sampled],true_expected_count:[batch_size,1]
    # sampled_expected_count: [num_sampled]
    # 采样的值不参与梯度更新,所以用stop_gradient标明
    sampled, true_expected_count, sampled_expected_count = (
        array_ops.stop_gradient(s) for s in sampled_values)
    sampled = math_ops.cast(sampled, dtypes.int64)
    # labels_flat:[batch_size * num_true],sampled: [num_sampled]
    #将正label和负label对应的索引合并到一起
    all_ids = array_ops.concat([labels_flat, sampled], 0)
    #通过索引all_ids从权重参数矩阵weights:[num_classes, dim],取出对应的权重参数,得到all_w
    all_w = embedding_ops.embedding_lookup(
        weights, all_ids,partition_strategy=partition_strategy)
    if all_w.dtype != inputs.dtype:
      all_w = math_ops.cast(all_w, inputs.dtype)
    # 抽离出正label w权重参数 true_w,和负label权重参数sampled_w
    #true_w :[batch_size * num_true, dim]
    # sampled_w: [num_sampled, dim], 一个batch里,每个样本的负label都是一样的
    true_w = array_ops.slice(all_w, [0, 0],
                             array_ops.stack(
                                 [array_ops.shape(labels_flat)[0], -1]))
    sampled_w = array_ops.slice(
        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    #在对应的负label上,计算wx+b,inputs: [batch_size, dim]
    # sampled_logits: [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
    # 与计算all_w一样,抽取偏移all_b
    all_b = embedding_ops.embedding_lookup(
        biases, all_ids, partition_strategy=partition_strategy)
    if all_b.dtype != inputs.dtype:
      all_b = math_ops.cast(all_b, inputs.dtype)
    # 抽离出正,负偏移b
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
    # inputs: [batch_size, dim]
    # true_w: [batch_size * num_true, dim]
    # 计算wx+b,得到true_logits:[ batch_size, num_true]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
    # 做点乘,得到row_wise_dots: [batch_size, num_true, dim]
    row_wise_dots = math_ops.multiply(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    #reshape
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat([[-1], dim], 0))
    # 得到正label对应的logits值 [batch_size, num_true]                                                               
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    # +b
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b
    sampled_logits += sampled_b
################## 此段代码去掉采样的label在正label里  ###########
    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(
          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,"sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          [array_ops.shape(labels)[:1],
           array_ops.expand_dims(num_sampled, 0)], 0)
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += gen_sparse_ops.sparse_to_dense(
          sparse_indices,
          sampled_logits_shape,
          acc_weights,
          default_value=0.0,
          validate_indices=False)
############# 此段代码表示是logits否减去-log(true_expected_count) #######
    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)
	#将正负logits concat到一起,得到out_logits: [batch_size, num_true+num_sampled]
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 标签labels,正label为1/num_true,保证总和为1,负label标签为0
	# out_labels: [batch_size, num_ture+num_sampled]
    out_labels = array_ops.concat([
        array_ops.ones_like(true_logits) / num_true,
        array_ops.zeros_like(sampled_logits) ], 1)
	return out_logits, out_labels

sigmoid_cross_entropy_with_logits函数
这就不多说了,交叉熵loss,因为label有多个,是multi-label分类,所以用sigmoid,要注意的一点是,函数参数logits传入的值是原始的wx+b值,sigmoid计算在函数里面操作。

_sum_rows函数

def _sum_rows(x):
  #该函数的类似tf.reduce_sum(x,1)操作
  #官方给出用这样计算的理由是,计算梯度效率更高
  cols = array_ops.shape(x)[1]
  ones_shape = array_ops.stack([cols, 1])
  ones = array_ops.ones(ones_shape, x.dtype)
  # x:[batch_size, num_true+num_sampled]
  # ones: [num_true+num_sampled, 1]
  #x和ones两个矩阵相乘,得到[batch_size,1],再reshape [batch_size]
  return array_ops.reshape(math_ops.matmul(x, ones), [-1])

2.3 缺点

从tensorflow源代码知道,要求每个输入的batch的正label个数必须一致,个数为num_true,所以正常训练模型的时候,必须每一个batch的样本正label一样,但是在实际应用中,特别是multi-label分类,每个样本的正label个数很多是不一致的,在multi task任务中,更不能保证一个batch在多个任务的label标签上都是一致的。

3 正label个数不一致解决方案

针对上述缺陷,尝试如下方案 ,已试验可行。

3.1 增加一个pad标签作为负label

核心思想生成样本的时候,将每个样本的label统一长度为num_true,不足的,用索引为0 (代表pad) 的标签填充,在计算loss的时候,让pad类别对应为负label
修改代码主要如下:

#修改前函数
def _compute_sampled_logits(...):
	...
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 对应的源代码生成label过程
	out_labels = array_ops.concat([
    	array_ops.ones_like(true_logits) / num_true,
    	array_ops.zeros_like(sampled_logits)], 1)
    return out_logits, out_labels

#修改后函数
def _compute_sampled_logits(...):
	...
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 生成mask矩阵,其中真实的正label元素为1, 填充pad label为0
    mask = tf.cast(tf.not_equal(labels, 0), tf.float32)
    # 将pad的label都为负label 0
    true_y = array_ops.ones_like(true_logits) * mask 
    # 然后用div_flag控制是否需要对每个样本的label除以每个样本的个数
    # 这里动态的计算每个样本的真实label数量,因为每个样本pad的个数不一致
    if div_flag:
        dynamic_num_true = tf.reduce_sum(tf.sign(labels), reduction_indices=1)
        dynamic_num_true = tf.cast(tf.expand_dims(dynamic_num_true, -1), tf.float32)
        true_y = tf.div(true_y, dynamic_num_true)
    # 将正label和负label组合,得到out_labels返回
    out_labels = array_ops.concat([
        true_y,
        array_ops.zeros_like(sampled_logits)], 1)
 	return out_logits, out_labels  

4 参考

Noise-contrastive estimation: A new estimation principle for
unnormalized statistical models

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值