ConvR: Adaptive Convolution for Multi-Relational Learning

Convolution-based knowledge graph completion:
ConvE: Convolutional 2D Knowledge Graph Embeddings
ConvKB code: A Novel Embedding Model for Knowledge Base Completion Based on Convolutional Neural Network
ConvR paper: link

1 Introduction

This paper builds on ConvE to improve knowledge graph completion. ConvE does make entities and relations interact, but the interaction is limited: only about 20% of the convolution windows cover both the entity and the relation. ConvR therefore proposes adaptive convolution to strengthen the interaction between entities and relations.

2 Model

2.1 Model diagram

(figure: model diagrams of ConvE and ConvR)

The idea of ConvE

1、It simply stacks the embeddings of s and r, so only part of the output mixes the two, roughly 20%: the region where yellow and blue meet after convolution in figure (a). A toy count of this ratio is sketched right after this list.
2、The convolution is global: in the implementation, all inputs share the same single set of filters.
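
To make the roughly-20% figure concrete, here is a toy count (illustrative only; the 10×10 reshape and 5×5 kernel below are assumed values, not necessarily the paper's): stack two reshaped embeddings the way ConvE does and count how many convolution windows touch both halves.

import torch

H, W, k = 10, 10, 5                        # assumed reshape size and kernel size
# s (zeros) stacked on top of r (ones), as in ConvE
stacked = torch.cat([torch.zeros(H, W), torch.ones(H, W)], dim=0)

total, mixed = 0, 0
for i in range(2 * H - k + 1):             # valid window positions along height
    for j in range(W - k + 1):             # valid window positions along width
        window = stacked[i:i + k, j:j + k]
        total += 1
        if window.min() == 0 and window.max() == 1:   # window covers both s and r
            mixed += 1
print(f"{mixed}/{total} = {mixed / total:.0%}")       # 24/96 = 25%

With these shapes only the windows straddling the boundary row mix s and r, about a quarter of them, in the same ballpark as the 20% quoted above.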

The idea of ConvR

1、ConvR improves on these problems of ConvE. To achieve richer interaction, it splits the embedding of r into $c$ convolution kernels of size $h \times w$, so the embedding dimension of r is $chw$.
2、s is the input being convolved, and each s is convolved with the kernels belonging to its own relation; this way the entity s and the relation r interact fully.
3、This design also reduces the number of model parameters (see the back-of-the-envelope comparison after this list).
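
One place the saving shows up is the fully connected projection layer. A back-of-the-envelope comparison under hypothetical hyper-parameters (the numbers below are assumptions for illustration, not taken from the paper):

d_e = 100                                        # entity dim, reshaped to 10x10
# ConvE: stack two 10x10 maps into a 20x10 input; 32 shared 3x3 filters
conve_fc_in = 32 * (20 - 3 + 1) * (10 - 3 + 1)   # 32 * 18 * 8 = 4608
# ConvR (this implementation): 100 relation-specific 5x5 filters on a
# 10x10 input, with feature maps summed over channels before the FC
convr_fc_in = (10 - 5 + 1) * (10 - 5 + 1)        # 6 * 6 = 36
print(conve_fc_in * d_e, convr_fc_in * d_e)      # 460800 vs 3600 FC weights

The relation embedding table does get larger (dimension $chw$ instead of $d_e$), but since there are usually far fewer relations than entities, the FC saving tends to dominate.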

2.2 Interaction explained

(figure: how the entity interacts with its relation-specific filters)
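
Since the figure is missing, here is a minimal sketch of the interaction step instead (shapes are assumed; it mirrors the grouped-convolution trick used in the full implementation in section 3): each relation embedding is reshaped into its own filter bank and applied only to its own entity.

import torch
import torch.nn.functional as F

B, c, h, w, H, W = 4, 100, 5, 5, 10, 10    # batch, filter count, kernel, reshape
s = torch.randn(B, 1, H, W)                # reshaped subject embeddings
r = torch.randn(B, c * h * w)              # relation embeddings, dim = c*h*w

inp = s.view(1, B, H, W)                   # fold the batch into the channel axis
filt = r.view(B * c, 1, h, w)              # B groups of c relation-specific filters
out = F.conv2d(inp, filt, groups=B)        # group i: entity i * filters of relation i
out = out.view(B, c, H - h + 1, W - w + 1)
print(out.shape)                           # torch.Size([4, 100, 6, 6])

Because every sample forms its own group, each feature map comes from an entity convolved with filters derived entirely from its relation, which is exactly the full interaction claimed above.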

2.3 Loss function

(figure: loss function)
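
The original figure is missing; reading the loss off the ConvRLoss class in the code below, it is binary cross-entropy against label-smoothed one-hot targets, summed over all $N$ entities and averaged over the batch of size $B$, with smoothing coefficient $\epsilon$:

$t'_{b,i} = (1-\epsilon)\,t_{b,i} + \dfrac{\epsilon}{N}$

$\mathcal{L} = -\dfrac{1}{B}\sum_{b=1}^{B}\sum_{i=1}^{N}\Big[\,t'_{b,i}\log p_{b,i} + (1-t'_{b,i})\log(1-p_{b,i})\,\Big]$

where $p_{b,i}$ is the sigmoid score of entity $e_i$ as the tail of $(s_b, r_b)$.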

3 Code

Code: source

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from .BaseModel import BaseModel


class ConvR(BaseModel):
    def __init__(self, config):
        super(ConvR, self).__init__(config)
        self.device = config.get('device')
        self.entity_cnt = config.get('entity_cnt')
        self.relation_cnt = config.get('relation_cnt')
        kwargs = config.get('model_hyper_params')
        self.conv_out_channels = kwargs.get('conv_out_channels')  # e.g. 100
        self.reshape = kwargs.get('reshape')  # e.g. [10, 10]
        self.kernel_size = kwargs.get('conv_kernel_size')  # e.g. [5, 5]
        self.stride = kwargs.get('stride')
        self.emb_dim = {
            'entity': kwargs.get('emb_dim'),  # must equal reshape[0] * reshape[1]
            'relation': self.conv_out_channels * self.kernel_size[0] * self.kernel_size[1]
        }
        assert self.emb_dim['entity'] == self.reshape[0] * self.reshape[1]
        self.E = torch.nn.Embedding(self.entity_cnt, self.emb_dim['entity'])
        self.R = torch.nn.Embedding(self.relation_cnt, self.emb_dim['relation'])
        self.input_drop = torch.nn.Dropout(kwargs.get('input_dropout'))
        self.feature_map_drop = torch.nn.Dropout2d(kwargs.get('feature_map_dropout'))
        self.hidden_drop = torch.nn.Dropout(kwargs.get('hidden_dropout'))
        self.bn0 = torch.nn.BatchNorm2d(1)  # batch normalization over a 4D input
        self.bn1 = torch.nn.BatchNorm2d(self.conv_out_channels)
        self.bn2 = torch.nn.BatchNorm1d(self.emb_dim['entity'])
        self.bn3 = torch.nn.BatchNorm1d(self.emb_dim['relation'])
        self.register_parameter('b', Parameter(torch.zeros(self.entity_cnt)))
        # feature-map size after the relation-specific convolution
        self.filtered = [(self.reshape[0] - self.kernel_size[0]) // self.stride + 1,
                         (self.reshape[1] - self.kernel_size[1]) // self.stride + 1]
        fc_length = self.filtered[0] * self.filtered[1]  # channels are summed in forward()
        self.fc = torch.nn.Linear(fc_length, self.emb_dim['entity'])
        self.loss = ConvRLoss(self.device, kwargs.get('label_smoothing'), self.entity_cnt)
        self.init()
    
    def init(self):
        torch.nn.init.xavier_normal_(self.E.weight.data)
        torch.nn.init.xavier_normal_(self.R.weight.data)
    
    def forward(self, batch_h, batch_r, batch_t=None):
        batch_size = batch_h.size(0)
        
        e1 = self.E(batch_h).view(-1, 1, *self.reshape)   # (B, 1, H, W)
        e1 = self.bn0(e1).view(1, -1, *self.reshape)      # fold batch into channels: (1, B, H, W)
        e1 = self.input_drop(e1)

        r = self.R(batch_r)
        r = self.bn3(r)
        r = self.input_drop(r)
        r = r.view(-1, 1, *self.kernel_size)              # (B * c, 1, h, w): c filters per sample

        # grouped convolution: sample i is convolved only with the filters of relation i;
        # stride must match the one used to compute self.filtered above
        x = F.conv2d(e1, r, groups=batch_size, stride=self.stride)
        x = x.view(batch_size, self.conv_out_channels, *self.filtered)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_map_drop(x)
        x = x.sum(dim=1)                                  # sum feature maps over channels
        x = x.view(batch_size, -1)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = torch.mm(x, self.E.weight.transpose(1, 0))    # score against every entity
        x += self.b.expand_as(x)                          # per-entity bias
        y = torch.sigmoid(x)
        return self.loss(y, batch_t), y

class ConvRLoss(torch.nn.Module):  # a plain Module suffices; no model config is needed here
    def __init__(self, device, label_smoothing, entity_cnt):
        super().__init__()
        self.device = device
        self.loss = torch.nn.BCELoss(reduction='sum')
        self.label_smoothing = label_smoothing
        self.entity_cnt = entity_cnt
    
    def forward(self, batch_p, batch_t=None):
        batch_size = batch_p.shape[0]
        loss = None
        if batch_t is not None:
            # one-hot targets with label smoothing
            batch_e = torch.zeros(batch_size, self.entity_cnt).to(self.device).scatter_(1, batch_t.view(-1, 1), 1)
            batch_e = (1.0 - self.label_smoothing) * batch_e + self.label_smoothing / self.entity_cnt
            loss = self.loss(batch_p, batch_e) / batch_size   # mean BCE per example
        return loss
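
For a quick smoke test, something like the following should work (the config keys are inferred from __init__ above; the values are arbitrary, and BaseModel must be importable):

config = {
    'device': 'cpu',
    'entity_cnt': 1000,
    'relation_cnt': 50,
    'model_hyper_params': {
        'emb_dim': 100,                # must equal 10 * 10 (the reshape below)
        'reshape': [10, 10],
        'conv_out_channels': 100,
        'conv_kernel_size': [5, 5],
        'stride': 1,
        'input_dropout': 0.2,
        'feature_map_dropout': 0.2,
        'hidden_dropout': 0.3,
        'label_smoothing': 0.1,
    },
}
model = ConvR(config)
h = torch.randint(0, 1000, (8,))       # head entity ids
r = torch.randint(0, 50, (8,))         # relation ids
t = torch.randint(0, 1000, (8,))       # tail entity ids
loss, scores = model(h, r, t)          # scores: (8, 1000) probabilities over all entities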