1 目的
在分类任务中,如果分类类别量级较大,几十万甚至上百万的量级,那么最后一层分类层计算将会十分耗时,为了降低模型计算复杂度,每次前向计算,采样部分label参数 w w w参与计算,反向计算梯度的时候也只更新部分参与计算的部分 w w w,而不需要每次更新全部的权重参数 w w w,这样可以大大提高模型的训练速度。
2 tensorflow源码解读
2.1函数的输入和输出
nn_impl.py这是NCE loss的tensorflow源代码,接下来我们对源代码进行一个梳理和讲解,首先我们来看下nce_loss在tensorflow源码中的实现,如下代码所示:
def nce_loss(weights,
biases,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
sampled_values=None,
remove_accidental_hits=False,
partition_strategy="mod",
name="nce_loss"):
#计算采样的labels和对应的logits(wx+b)值
logits, labels = _compute_sampled_logits(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
# 交叉熵loss
sampled_losses = sigmoid_cross_entropy_with_logits(
labels=labels, logits=logits, name="sampled_losses")
# 返回loss求和
return _sum_rows(sampled_losses)
函数输入参数:
- weights: [num_classes, dim],最后一层分类层的权重参数 w w w
- biases: [num_classes],最后一层分类层的偏移 b b b
- labels: [batch_size, num_true],int64类型,每个batch的正label索引idex,要求每个样本的正label数量必须一致,值为num_true (这也是在实际应用中不灵活的部分,后面会有改进的方案)
- inputs: [batch_size, dim],输入分类层特征向量
- num_sampled: int类型,每个样本随机采样的负样本个数
- num_classes: int类型,分类的label总数量
- num_true: int类型,每个样本的正label数量(一个batch里的所有样本的正label必须一致)
- sampled_values: 自定义采样的候选集,是个三元组 (采样的候选集,正label数量,采样的label数量),默认是None,采用log_uniform_candidate_sampler采用器
- remove_accidental_hits: bool类型,是否去除采样到的label有在正label集合里的,设置为“True"则会用负采样loss而不是NCE。
- partition_strategy: 两种模式“mod"和”div",默认”mod",详情可以参考tf.nn.embedding_lookup
- name: operation的名称
函数返回值:
一维的向量,长度大小为[batch_size],对应每个样本的loss值。
2.2 代码讲解
接下来我们对每个函数的实现做一个深入分析,由上可知,nce_loss函数下主要有三个函数组成,_compute_sampled_logits,sigmoid_cross_entropy_with_logits和_sum_rows。
_compute_sampled_logits函数
def _compute_sampled_logits(weights,
biases,
labels,
inputs,
num_sampled,
num_classes,
num_true=1,
sampled_values=None,
subtract_log_q=True,
div_flag=True,
remove_accidental_hits=False,
partition_strategy="mod",
name=None,
seed=None):
#数据格式转
if isinstance(weights, variables.PartitionedVariable):
weights = list(weights)
if not isinstance(weights, list):
weights = [weights]
# 数据格式转换,将label [batch_size, num_ture] 展开,得到一维的size
with ops.name_scope(name, "compute_sampled_logits",
weights + [biases, inputs, labels]):
if labels.dtype != dtypes.int64:
labels = math_ops.cast(labels, dtypes.int64)
labels_flat = array_ops.reshape(labels, [-1])
#如果采样label不传入,则默认用log_unifrom_candidate_sampler采样器,生成采用的label
if sampled_values is None:
sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=labels,
num_true=num_true,
num_sampled=num_sampled,
unique=True,
range_max=num_classes,
seed=seed)
# sampled:[num_sampled],true_expected_count:[batch_size,1]
# sampled_expected_count: [num_sampled]
# 采样的值不参与梯度更新,所以用stop_gradient标明
sampled, true_expected_count, sampled_expected_count = (
array_ops.stop_gradient(s) for s in sampled_values)
sampled = math_ops.cast(sampled, dtypes.int64)
# labels_flat:[batch_size * num_true],sampled: [num_sampled]
#将正label和负label对应的索引合并到一起
all_ids = array_ops.concat([labels_flat, sampled], 0)
#通过索引all_ids从权重参数矩阵weights:[num_classes, dim],取出对应的权重参数,得到all_w
all_w = embedding_ops.embedding_lookup(
weights, all_ids,partition_strategy=partition_strategy)
if all_w.dtype != inputs.dtype:
all_w = math_ops.cast(all_w, inputs.dtype)
# 抽离出正label w权重参数 true_w,和负label权重参数sampled_w
#true_w :[batch_size * num_true, dim]
# sampled_w: [num_sampled, dim], 一个batch里,每个样本的负label都是一样的
true_w = array_ops.slice(all_w, [0, 0],
array_ops.stack(
[array_ops.shape(labels_flat)[0], -1]))
sampled_w = array_ops.slice(
all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
#在对应的负label上,计算wx+b,inputs: [batch_size, dim]
# sampled_logits: [batch_size, num_sampled]
sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
# 与计算all_w一样,抽取偏移all_b
all_b = embedding_ops.embedding_lookup(
biases, all_ids, partition_strategy=partition_strategy)
if all_b.dtype != inputs.dtype:
all_b = math_ops.cast(all_b, inputs.dtype)
# 抽离出正,负偏移b
true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
# inputs: [batch_size, dim]
# true_w: [batch_size * num_true, dim]
# 计算wx+b,得到true_logits:[ batch_size, num_true]
dim = array_ops.shape(true_w)[1:2]
new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
# 做点乘,得到row_wise_dots: [batch_size, num_true, dim]
row_wise_dots = math_ops.multiply(
array_ops.expand_dims(inputs, 1),
array_ops.reshape(true_w, new_true_w_shape))
#reshape
dots_as_matrix = array_ops.reshape(row_wise_dots,
array_ops.concat([[-1], dim], 0))
# 得到正label对应的logits值 [batch_size, num_true]
true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
# +b
true_b = array_ops.reshape(true_b, [-1, num_true])
true_logits += true_b
sampled_logits += sampled_b
################## 此段代码去掉采样的label在正label里 ###########
if remove_accidental_hits:
acc_hits = candidate_sampling_ops.compute_accidental_hits(
labels, sampled, num_true=num_true)
acc_indices, acc_ids, acc_weights = acc_hits
# This is how SparseToDense expects the indices.
acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
acc_ids_2d_int32 = array_ops.reshape(
math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,"sparse_indices")
# Create sampled_logits_shape = [batch_size, num_sampled]
sampled_logits_shape = array_ops.concat(
[array_ops.shape(labels)[:1],
array_ops.expand_dims(num_sampled, 0)], 0)
if sampled_logits.dtype != acc_weights.dtype:
acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
sampled_logits += gen_sparse_ops.sparse_to_dense(
sparse_indices,
sampled_logits_shape,
acc_weights,
default_value=0.0,
validate_indices=False)
############# 此段代码表示是logits否减去-log(true_expected_count) #######
if subtract_log_q:
# Subtract log of Q(l), prior probability that l appears in sampled.
true_logits -= math_ops.log(true_expected_count)
sampled_logits -= math_ops.log(sampled_expected_count)
#将正负logits concat到一起,得到out_logits: [batch_size, num_true+num_sampled]
out_logits = array_ops.concat([true_logits, sampled_logits], 1)
# 标签labels,正label为1/num_true,保证总和为1,负label标签为0
# out_labels: [batch_size, num_ture+num_sampled]
out_labels = array_ops.concat([
array_ops.ones_like(true_logits) / num_true,
array_ops.zeros_like(sampled_logits) ], 1)
return out_logits, out_labels
sigmoid_cross_entropy_with_logits函数
这就不多说了,交叉熵loss,因为label有多个,是multi-label分类,所以用sigmoid,要注意的一点是,函数参数logits传入的值是原始的wx+b值,sigmoid计算在函数里面操作。
_sum_rows函数
def _sum_rows(x):
#该函数的类似tf.reduce_sum(x,1)操作
#官方给出用这样计算的理由是,计算梯度效率更高
cols = array_ops.shape(x)[1]
ones_shape = array_ops.stack([cols, 1])
ones = array_ops.ones(ones_shape, x.dtype)
# x:[batch_size, num_true+num_sampled]
# ones: [num_true+num_sampled, 1]
#x和ones两个矩阵相乘,得到[batch_size,1],再reshape [batch_size]
return array_ops.reshape(math_ops.matmul(x, ones), [-1])
2.3 缺点
从tensorflow源代码知道,要求每个输入的batch的正label个数必须一致,个数为num_true,所以正常训练模型的时候,必须每一个batch的样本正label一样,但是在实际应用中,特别是multi-label分类,每个样本的正label个数很多是不一致的,在multi task任务中,更不能保证一个batch在多个任务的label标签上都是一致的。
3 正label个数不一致解决方案
针对上述缺陷,尝试如下方案 ,已试验可行。
3.1 增加一个pad标签作为负label
核心思想:生成样本的时候,将每个样本的label统一长度为num_true,不足的,用索引为0 (代表pad) 的标签填充,在计算loss的时候,让pad类别对应为负label。
修改代码主要如下:
#修改前函数
def _compute_sampled_logits(...):
...
out_logits = array_ops.concat([true_logits, sampled_logits], 1)
# 对应的源代码生成label过程
out_labels = array_ops.concat([
array_ops.ones_like(true_logits) / num_true,
array_ops.zeros_like(sampled_logits)], 1)
return out_logits, out_labels
#修改后函数
def _compute_sampled_logits(...):
...
out_logits = array_ops.concat([true_logits, sampled_logits], 1)
# 生成mask矩阵,其中真实的正label元素为1, 填充pad label为0
mask = tf.cast(tf.not_equal(labels, 0), tf.float32)
# 将pad的label都为负label 0
true_y = array_ops.ones_like(true_logits) * mask
# 然后用div_flag控制是否需要对每个样本的label除以每个样本的个数
# 这里动态的计算每个样本的真实label数量,因为每个样本pad的个数不一致
if div_flag:
dynamic_num_true = tf.reduce_sum(tf.sign(labels), reduction_indices=1)
dynamic_num_true = tf.cast(tf.expand_dims(dynamic_num_true, -1), tf.float32)
true_y = tf.div(true_y, dynamic_num_true)
# 将正label和负label组合,得到out_labels返回
out_labels = array_ops.concat([
true_y,
array_ops.zeros_like(sampled_logits)], 1)
return out_logits, out_labels
4 参考
Noise-contrastive estimation: A new estimation principle for
unnormalized statistical models