Incorporating GAN for Negative Sampling in Knowledge Representation Learning笔记

基于GAN生成负样本使表示学习模型学习到更好的embedding

Motivation

知识表示学习在学习embedding过程的目标是最小化margin-based ranking loss,计算loss需要负样本,而以往的工作通过随机的方式构造负样本,生成的样本可能太简单,对模型的能力并不会有太大帮助。本文提出一种基于对抗学习(GAN)的框架获取高质量的负样本提升表示学习模型的能力,其中generator用于生成更好的负样本,discriminator用于学习KG中实体和关系的表示。实验证明将基于GAN框架生成的负样本用于知识表示学习模型在三元组分类和链接预测上的结果优于baseline模型。
针对random的方式,作者认为这会使模型随着训练出现zero loss的问题。模型训练开始时当绝大多数负样本和正样本的距离在margin内,随机采样的方法是有效的;随着训练过程的继续,随机构造的负样本可能会在margin外(红色的圈内),则会出现loss归0的情况。比如三元组(Steve Jobs, FounderOf, Apple Inc.),随机替换头实体时,可能用于替换的是London或baseball,导致zero loss发生。模型只能区分Steve Jobs不同于non-person的实体,而不能学习到这个实体的概念。
在这里插入图片描述

KGE模型训练介绍

KGE模型最小化margin-based ranking loss,给定一个三元组,成立时模型给一个低的分数,不成立模型给一个高分,γ 为margin;损失函数和目标函数分别为:
在这里插入图片描述
在这里插入图片描述
负样本集合指替换头实体或尾实体构成的集合:
在这里插入图片描述
训练过程,随机选取正样本,当正样本被选中,已有方法从实体集合随机选取实体替换头实体或尾实体,替换时设置了两种策略,“unif”等概率替换头实体或尾实体,“bern”根据bernoulli分布选择替换头实体还是尾实体。

基于GAN的框架

在这里插入图片描述

生成器

生成器目标:为判别器提供高质量负样本;需要注意的是,生成器和判别器的embeddings不同。并且每个关系有两个独立的embedding分别表示关系和关系的逆,对于每个正样本,生成器的输入是实体-关系的pair,具体定义为(z为表示替换头实体 or 尾实体,通过unifbern设置决定):
在这里插入图片描述
模型图里把实体和关系的向量拼起来传给两层的全连接神经网络(第一层使用ReLU,第二层使用Softmax),该网络和生成器的向量矩阵用来在整个实体集合上参数化概率分布,分布定义为:
在这里插入图片描述
由于生成器的输出值是离散的实体,所以作者使用了基于强化学习的梯度策略,使用判别器计算reward函数,定义为:
在这里插入图片描述
tanh激活函数使得:当生成的负样本和正样本没有zero loss存在时,reward为正,反之为负。
生成器的训练目标是最大化reward的期望:
在这里插入图片描述
生成器的更新:
在这里插入图片描述

判别器

右边的判别器使用生成器提供的负样本学习KG的表示,这部分的设计与KGE模型相同,训练目标为最小化margin-based ranking loss。

训练过程

在这里插入图片描述
每个epoch,在mini-batch上迭代训练集,固定判别器的参数,训练生成器;再次迭代训练集,固定生成器的参数,训练判别器;整个过程被认为是先通过生成器搜索更好的负样本,过滤更不可靠的负样本,再传递给判别器。

实验部分模型采用了FB15K, FB13, WN11, WN18数据集,规模如下:
在这里插入图片描述
在训练过程主要分为两种设置GAN-scratch, GAN-pretrain
GAN-scratch:生成器和判别器的参数随机初始化
GAN-pretrain:先用随机负样本训练模型,然后用学到的embedding初始化判别器的参数,生成器的参数仍随机初始化,整个模型的训练看作在KGE模型上的fine-tune。

实验结果

链接预测:
用来预测三元组中缺失的实体,评估模型学到的embedding。评估方式是对不完整三元组的候选实体进行排序,对每个测试三元组,我们用实体集的所有实体替换头实体或尾实体,然后计算每个测试三元组的score和它们被替换实体后构成的三元组的score,对score按降序排序。
最后用两个指标反映结果:
mean rank指正确实体的平均排序,hits@10表示正确实体排在前10位的比例,mean rank小或者hits@10大表示结果好。
Filter设置:一个被替换实体后(被打乱)的三元组也可能在KG中存在,把它们排在原始三元组前面是可以的,于是过滤掉在训练集、验证集、测试集存在的被打乱的三元组,避免对模型能力的低估。
在这里插入图片描述
三元组分类:
判断一个三元组是否存在,二元分类问题,数据集WN11, FB13,在验证集和测试集的每个正样本已经有一个负样本,负样本随机构造,替换的限制是当实体在数据集中出现在同一个位置(head or tail)时才能被作为用来替换的实体。
三元组分类对每个关系设置一个阈值,如果score低于阈值,三元组被认为是positive,反之negative。通过在验证集最大化分类准确率优化阈值。
在这里插入图片描述
作者发现GAN-pretrain比GAN-scratch的结果要好,认为原因一:预训练的模型已经在一些非最佳状态收敛,但由于zero loss问题而需要有提升;另一方面GAN-pretrain的搜索空间比GAN-scratch小,判别器的reward已经比较稳定,它允许生成器的策略网络直接搜索margin内的实体。

辅助实验:
可视化生成的高质量负样本:
在这里插入图片描述
判别器的训练区别正样本和高质量的负样本使得模型更好地学习实体。

未来的工作

框架泛化到其他问题的负样本生成,生成器采用复杂的模型。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值