深度学习中的对抗损失怎么使用

对前几天的对抗损失总结一下,转载请注明出处,如有不对的地方,欢迎前来指出,一起探讨。

1.对抗损失的目的与作用

    对抗损失的使用主要是为了减少标注数据,在真实的业务中,对于数据的标注是一件非常头疼的事,为了使用1000条标注能够达到2000条标注数据的所能达到效果(打个比方),模拟真实世界中各种噪声的情况,让模型更加鲁棒,更好用,准确率更高,在图像处理中经常使用引入噪声来增加图像的样本集,在文本类数据中怎么来使用呢?

    那具体怎么做呢?一般方法是将真实数据+噪声,使得数据散度更大,让有标签与无标签的数据在模型上一起进行学习。

2.对抗损失的方法

  • 一般对抗损失(在进行一般对抗的时候,首先对要对batch_size的数据进行一次损失的计算,根据损失对输入(可以是字向量、词向量、词性向量、偏旁向量等特征的拼接)的向量进行偏导数计算,但要注意的是需要对输入数据进行stop_gradiants,然后刚才求得的Grad进行L2正则化处理,将处理后的值称作扰动perturb,将perturb与embedding一起add,然后在进行一次损失的计算,起到了对每个batch_size的数据进行了两次的计算,增加模型的鲁棒性)
  • 随机对抗损失(随机生成一个形状与embedding相同的向量,然后进行MASK操作,再进行L2正则化生成pertub,最后将生成的噪声添加到输入特征向量上进行再次计算损失)
  • 虚拟对抗损失(虚拟对抗与随机损失有点相似,但是引入了KL散度,具体看下面实现)

3.对抗损失具体实现

一般对抗损失

def adversarial_loss(embedded, loss, loss_fn):
  """Adds gradient to embedding and recomputes classification loss."""
  grad, = tf.gradients(
      loss,
      embedded)
  grad = tf.stop_gradient(grad)
  perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)

随机对抗损失

def random_perturbation_loss(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = tf.random_normal(shape=tf.shape(embedded))
  perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)

虚拟对抗损失

def virtual_adversarial_loss(logits, embedded, inputs,
                             logits_from_embedding_fn):
  """Virtual adversarial loss.

  Computes virtual adversarial perturbation by finite difference method and
  power iteration, adds it to the embedding, and computes the KL divergence
  between the new logits and the original logits.

  Args:
    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
      num_classes=2, otherwise m=num_classes.
    embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
    inputs: VatxtInput.
    logits_from_embedding_fn: callable that takes embeddings and returns
      classifier logits.

  Returns:
    kl: float scalar.
  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
  weights = inputs.eos_weights
  assert weights is not None
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf,
  # Adding small noise to input and taking gradient with respect to the noise
  # corresponds to 1 power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)

    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl,
        d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)

 

 

 

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值