TuckER:Tensor Factorization for Knowledge Graph Completion

论文来源:ICML2019

论文链接:https://arxiv.org/abs/1901.09590

代码链接:https://github.com/ibalazevic/TuckER

  • 总结:这篇文章利用了一个高级的公式,可能很多人看到这个公式就怕了,确实如此,我看到这个公式就晕了,不知道其中的具体含义,其实如果你实在搞不懂具体的意思,也可以忽略,大致理解作者的想法就行了,我是在看完一遍paper之后,然后去看了代码,于是对这篇文章的训练过程就有了一定的了解,然后又将paper看了一遍,大家也可以采用这种学习方法去阅读paper。

1. 背景知识

1.1 Tucker Decomposition

  • 定义:将一个张量分解为一组矩阵一个核张量(core tensor)

  • 公式如下:
    X ≈ Z × 1 A × 2 B × 3 C \mathcal{X} \approx \mathcal{Z} \times_{1} \mathbf{A} \times{ }_{2} \mathbf{B} \times{ }_{3} \mathbf{C} XZ×1A×2B×3C
    其中: X ∈ R I × J × K \mathcal{X} \in \mathbb{R}^{I \times J \times K } XRI×J×K Z ∈ R P × Q × R \mathcal{Z} \in \mathbb{R}^{P \times Q \times R} ZRP×Q×R A ∈ R I × P A \in \mathbb{R}^{I \times P} ARI×P B ∈ R J × Q B \in \mathbb{R}^{J \times Q} BRJ×Q C ∈ R K × R C \in \mathbb{R}^{K \times R} CRK×R

  • 公式解释:

    • × n \times_{n} ×n表示沿着模式n的张量乘积(可以简单理解为一种计算公式的简化写法)
    • ABC可以理解为每种模式下的主成分。
    • Z \mathcal{Z} Z中的每个元素代表了不同成分之间的交互程度。
    • 其中 PQR 分别小于 IJK,所以也可以认为 Z \mathcal{Z} Z X \mathcal{X} X的压缩版本

2. 模型架构

在这里插入图片描述

  • 评分函数:
    ϕ ( e s , r , e o ) = W × 1 e s × 2 w r × 3 e o \phi\left(e_{s}, r, e_{o}\right)=\mathcal{W} \times_{1} \mathbf{e}_{s} \times_{2} \mathbf{w}_{r} \times_{3} \mathbf{e}_{o} ϕ(es,r,eo)=W×1es×2wr×3eo
    其中: e s 、 w r 、 e o \mathbf{e}_{s}、\mathbf{w}_{r}、\mathbf{e}_{o} eswreo表示头实体、关系和尾实体的嵌入; d e 、 d r d_e、d_r dedr分别表示实体和关系的嵌入维度; W ∈ R d e × w r × d e \mathcal{W} \in \mathbb{R}^{d_e \times w_r \times d_e} WRde×wr×de表示核张量

3. 模型训练

  • 将上述评分函数得到的结果输入到sigmoid函数中,会得到一个概率值,然后计算下面的损失值:
    L = − 1 n e ∑ i = 1 n e ( y ( i ) log ⁡ ( p ( i ) ) + ( 1 − y ( i ) ) log ⁡ ( 1 − p ( i ) ) ) L=-\frac{1}{n_{e}} \sum_{i=1}^{n_{e}}\left(\mathbf{y}^{(i)} \log \left(\mathbf{p}^{(i)}\right)+\left(1-\mathbf{y}^{(i)}\right) \log \left(1-\mathbf{p}^{(i)}\right)\right) L=ne1i=1ne(y(i)log(p(i))+(1y(i))log(1p(i)))

  • 在看代码之前,我认为这个评分函数是对整个三元组的评分,实际上代码中并不是这个意思

  • 前向传播计算输出的是,一个矩阵,大小为(batch,len(entity)),输入的是一个batch的头实体和关系,相当于是在预测尾实体出现在每个位置的概率值。然后将这个概率值同目标位置组成的矩阵(正确位置为1)计算loss(可以理解为二分类问题)

4. 核心代码以及解释

class TuckER(torch.nn.Module):
    def __init__(self, d, d1, d2, **kwargs):
        '''
        :param d: 数据集
        :param d1: 实体嵌入维度 200
        :param d2: 关系嵌入维度 200
        :param kwargs: 字典
        '''
        super(TuckER, self).__init__()
        self.E = torch.nn.Embedding(len(d.entities), d1)
        self.R = torch.nn.Embedding(len(d.relations), d2)
        self.W = torch.nn.Parameter(torch.tensor(np.random.uniform(-1, 1, (d2, d1, d1)),
                                                 dtype=torch.float, device="cuda", requires_grad=True))
        self.input_dropout = torch.nn.Dropout(kwargs["input_dropout"])
        self.hidden_dropout1 = torch.nn.Dropout(kwargs["hidden_dropout1"])
        self.hidden_dropout2 = torch.nn.Dropout(kwargs["hidden_dropout2"])
        self.loss = torch.nn.BCELoss()	# 类似于二分类的交叉熵损失
        self.bn0 = torch.nn.BatchNorm1d(d1)
        self.bn1 = torch.nn.BatchNorm1d(d1)
        torch.nn.init.xavier_normal_(self.E.weight.data)
        torch.nn.init.xavier_normal_(self.R.weight.data)

    def forward(self, e1_idx, r_idx):
        '''简单理解: pred=e1*r*W*E'''
        e1 = self.E(e1_idx)
        x = self.bn0(e1)
        x = self.input_dropout(x)  # [128,200]
        x = x.view(-1, 1, e1.size(1))  # [128,1,200]
		
        r = self.R(r_idx)
        W_mat = torch.mm(r, self.W.view(r.size(1), -1))  # [128,40000]
        W_mat = W_mat.view(-1, e1.size(1), e1.size(1))  # [128,200,200]
        W_mat = self.hidden_dropout1(W_mat)

        x = torch.bmm(x, W_mat)  # [128,1,200]
        x = x.view(-1, e1.size(1))  # [128,200]
        x = self.bn1(x)
        x = self.hidden_dropout2(x)
        x = torch.mm(x, self.E.weight.transpose(1, 0))  # transpose相当于转置 [128,14541]
        pred = torch.sigmoid(x)
        return pred
  • 模型分析:

    • 这个模型中要训练的参数一共有13个

      W
      E.weight
      R.weight
      bn0.weight
      bn0.bias
      bn0.running_mean
      bn0.running_var
      bn0.num_batches_tracked
      bn1.weight
      bn1.bias
      bn1.running_mean
      bn1.running_var
      bn1.num_batches_tracked
      
    • 训练过程中加入了反关系的三元组,例如 ( h , r , t ) (h,r,t) (h,r,t) ( t , r − 1 , h ) (t,r^{-1},h) (t,r1,h)

我是一个科研小白,努力分享自己认为有价值的paper~~

  • 5
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值