A summary of the adversarial losses I have been studying over the past few days. Please credit the source when reposting. If anything here is wrong, feel free to point it out and discuss.
1. Purpose and Role of Adversarial Loss
The main motivation for adversarial loss is to reduce the amount of labeled data needed. In real business settings, labeling data is a major pain point; ideally, 1,000 labeled examples should reach the performance you would otherwise need 2,000 labeled examples for (roughly speaking). Adversarial training also simulates the kinds of noise found in the real world, making the model more robust, more practical, and more accurate. In image processing, injecting noise is a standard way to augment the sample set. How do we do the same for text data?
So how is it done concretely? The general approach is to add noise to the real data, spreading out its distribution, and then train the model jointly on labeled and unlabeled data.
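Before looking at the TensorFlow implementations below, here is a minimal NumPy sketch of the core idea: for text, the noise is added to the embedding vectors rather than to pixels. The shapes and the 0.1 norm are illustrative assumptions, not values from the original code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical batch of word embeddings: (batch, timesteps, embed_dim).
embedded = rng.normal(size=(2, 5, 8))

# Gaussian noise of the same shape, rescaled so each example's perturbation
# has a fixed L2 norm; this plays the role pixel noise plays in image
# augmentation.
noise = rng.normal(size=embedded.shape)
norms = np.sqrt((noise ** 2).sum(axis=(1, 2), keepdims=True))
perturb = 0.1 * noise / norms
noisy_embedded = embedded + perturb
```

Training then computes the loss on both `embedded` and `noisy_embedded`, so the model learns to be stable under small perturbations of its input features.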
2. Adversarial Loss Methods
- Standard adversarial loss: first compute the loss for a batch, then take the gradient of that loss with respect to the input embedding (which may be a concatenation of character, word, POS, radical, and other feature vectors). Note that `tf.stop_gradient` must be applied to this gradient so it is treated as a constant. The gradient is then L2-normalized; the result is called the perturbation (`perturb`). Add the perturbation to the embedding and compute the loss a second time. Each batch therefore goes through two loss computations, which makes the model more robust.
- Random perturbation loss: generate a random tensor with the same shape as the embedding, mask it by sequence length, then L2-normalize it to produce the perturbation (`perturb`). Finally, add this noise to the input feature vectors and compute the loss again.
- Virtual adversarial loss: similar to the random perturbation loss, but it introduces a KL divergence between the original and perturbed predictions; see the implementation below.
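All three methods rely on a `_scale_l2` helper that rescales the perturbation to a fixed L2 norm per example. It is not shown in this post, so here is a NumPy sketch of its behavior; the division by the max-absolute value mirrors the numerical-stability trick the TensorFlow version uses, but the exact constants are assumptions.

```python
import numpy as np

def scale_l2_sketch(x, norm_length, eps=1e-12):
    """Rescale each example in x (batch, timesteps, dim) to L2 norm norm_length."""
    # Divide by the per-example max-abs value first for numerical stability,
    # then compute the L2 norm over the timestep and feature axes.
    alpha = np.abs(x).max(axis=(1, 2), keepdims=True) + eps
    l2 = alpha * np.sqrt(((x / alpha) ** 2).sum(axis=(1, 2), keepdims=True) + 1e-6)
    return norm_length * x / l2
```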
3. Implementations of Adversarial Loss
Standard adversarial loss:
```python
def adversarial_loss(embedded, loss, loss_fn):
  """Adds gradient to embedding and recomputes classification loss."""
  grad, = tf.gradients(
      loss,
      embedded)
  grad = tf.stop_gradient(grad)
  perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)
```
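To see why stepping along the stopped, L2-normalized gradient increases the loss, here is a NumPy sketch on a toy one-token linear model with squared loss. All names and shapes here are illustrative, not from the original code.

```python
import numpy as np

# Toy setup: a linear "classifier" on a single embedded token, squared loss.
rng = np.random.default_rng(1)
w = rng.normal(size=4)
embedded = rng.normal(size=4)
target = 1.0

def loss(e):
    return 0.5 * (w @ e - target) ** 2

# Gradient of the loss w.r.t. the embedding (what tf.gradients computes).
grad = (w @ embedded - target) * w

# Treat the gradient as a constant (the tf.stop_gradient step), rescale it
# to a fixed norm, and step in the direction that increases the loss.
perturb_norm_length = 0.5
perturb = perturb_norm_length * grad / np.linalg.norm(grad)
adv_loss = loss(embedded + perturb)
```

The second loss computation on `embedded + perturb` is what gets backpropagated, so the model learns to stay correct even at the locally worst-case perturbation of its input.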
Random perturbation loss:
```python
def random_perturbation_loss(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = tf.random_normal(shape=tf.shape(embedded))
  perturb = _scale_l2(
      _mask_by_length(noise, length), FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)
```
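The `_mask_by_length` helper above zeroes the noise at padded timesteps so the perturbation only touches real tokens. It is not shown in this post; a NumPy sketch of the intended behavior (function name and shapes are assumptions):

```python
import numpy as np

def mask_by_length_sketch(noise, lengths):
    """Zero out noise beyond each sequence's true length.

    noise: (batch, timesteps, dim); lengths: (batch,) ints.
    """
    _, timesteps, _ = noise.shape
    # mask[b, t] is True while t is within sequence b's length.
    mask = np.arange(timesteps)[None, :] < np.asarray(lengths)[:, None]
    return noise * mask[:, :, None]
```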
Virtual adversarial loss:
```python
def virtual_adversarial_loss(logits, embedded, inputs,
                             logits_from_embedding_fn):
  """Virtual adversarial loss.

  Computes virtual adversarial perturbation by finite difference method and
  power iteration, adds it to the embedding, and computes the KL divergence
  between the new logits and the original logits.

  Args:
    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
      num_classes=2, otherwise m=num_classes.
    embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
    inputs: VatxtInput.
    logits_from_embedding_fn: callable that takes embeddings and returns
      classifier logits.

  Returns:
    kl: float scalar.
  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
  weights = inputs.eos_weights
  assert weights is not None
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf.
  # Adding small noise to the input and taking the gradient with respect to
  # the noise corresponds to 1 power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length),
        FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl,
        d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)
```
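The `_kl_divergence_with_logits` helper used above is also not shown in this post. Here is a NumPy sketch of what a weighted KL divergence from raw logits looks like, assuming the simplified case where logits have shape (batch, num_classes) and `weights` is one value per example; the real 3-D, per-timestep version follows the same formula.

```python
import numpy as np

def softmax(logits):
    # Subtract the per-row max for numerical stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_divergence_with_logits_sketch(q_logits, p_logits, weights):
    """Weighted mean of KL(q || p), both distributions given as raw logits."""
    q = softmax(q_logits)
    log_q = np.log(q)
    log_p = np.log(softmax(p_logits))
    kl = (q * (log_q - log_p)).sum(axis=-1)      # per-example KL divergence
    return (weights * kl).sum() / weights.sum()  # weighted average
```

Because `q_logits` comes through `tf.stop_gradient`, minimizing this KL only moves the perturbed predictions toward the clean ones, which is what makes the method usable on unlabeled data.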