知识图谱补全技术-DistMult篇

知识图谱补全技术-DistMult篇



前言

在自然语言处理和机器学习领域,知识图谱是一种至关重要的数据结构。知识图谱通过节点表示实体,边表示实体之间的关系,构建了一个复杂的网络结构。在知识图谱中,如何高效地表示实体和关系是一个关键问题。为了解决这个问题,知识图谱补全技术应运而生。它通过预测和填补知识图谱中的缺失三元组,增强了数据的完整性和实用性。知识图谱补全技术利用嵌入模型和图神经网络等先进的机器学习方法,广泛应用于搜索引擎、推荐系统和智能问答系统等领域,从而显著提升了这些系统的性能和用户体验。

DistMult模型是知识图谱补全模型中的经典方法之一。DistMult通过一种对称的双线性模型来学习实体和关系的向量表示,即它假设所有关系都是对称的,这虽然简化了模型,但也限制了它对反对称关系的处理能力。尽管如此,DistMult在许多实际应用中仍表现出色。本文将介绍如何对知识图谱数据进行预处理,以及如何使用Python和PyTorch实现DistMult模型的训练和评估。


一、DistMult模型原理

DistMult模型是一种用于知识图谱补全的嵌入模型,其主要思想是通过双线性模型来表示实体和关系之间的相互作用。与其他知识图谱补全模型相比,DistMult模型具有计算简单、效果显著的特点。

1.模型假设

DistMult模型假设实体和关系都可以嵌入到相同的低维向量空间中。具体来说,给定一个知识图谱中的三元组 (ℎ,𝑟,𝑡)其中 ℎ是头实体,𝑟 是关系,𝑡是尾实体,DistMult模型通过以下方式来表示该三元组的得分:
在这里插入图片描述
这里,h、r 和 t 分别是头实体、关系和尾实体的向量表示,⟨⋅,⋅,⋅⟩表示向量之间的逐元素乘积和。

2.得分函数

DistMult模型的得分函数定义为:
在这里插入图片描述
其中,𝑑是向量的维度,ℎ𝑖、𝑟𝑖 和 𝑡𝑖分别是向量 h、r 和 t 在第 𝑖 维的分量。通过这个得分函数,DistMult模型可以评估一个三元组的合理性:得分越高,三元组越有可能是合理的。

3.模型训练

DistMult模型的训练目标是最大化正确三元组的得分,同时最小化错误三元组的得分。通常使用负采样技术来生成错误的三元组(即负例),并使用二元交叉熵损失函数来进行优化。具体的训练步骤如下:
正例和负例采样:对于每个正例
(ℎ,𝑟,𝑡),生成若干个负例 (ℎ′,𝑟,𝑡)(h ′ ,r,t) 或 (ℎ,𝑟,𝑡′)(h,r,t ′ ),其中 ℎ′ 和 𝑡′ 是随机选择的错误实体。
损失函数:使用二元交叉熵损失函数来定义正例和负例的损失。对于一个三元组 (ℎ,𝑟,𝑡)及其标签 𝑦(正例为1,负例为0),损失函数为:
在这里插入图片描述
其中,𝜎是 sigmoid 函数。
优化:使用梯度下降算法对模型参数进行优化,以最小化损失函数。

4.模型优势与局限

优势:
计算简单:DistMult模型的计算复杂度较低,适合大规模知识图谱。
效果显著:尽管假设关系是对称的,DistMult在许多实际任务中表现良好。

局限:
对称性限制:由于模型假设关系是对称的,因此对反对称关系(如“父母”与“子女”)的处理能力有限。
模型简单:相比于更复杂的模型(如ComplEx、RotatE),DistMult的表达能力较弱。

二、DistMult算法

1.准备好三元组数据集csv格式

csv文件内容三元组格式如图
请添加图片描述

2.安装必要的库

安装了以下Python包:
pandas
scikit-learn
torch

3.运行步骤

确保所有文件在同一目录下;
编辑data_preprocessing.py、train_distmult.py和completion_distmult.py中的文件路径,确保使用正确的数据文件路径;
运行data_preprocessing.py检查数据列名是否正确;
运行train_distmult.py训练模型;
运行completion_distmult.py进行知识补全。

4.部分代码展示

data_preprocessing.py

import pandas as pd
from sklearn.model_selection import train_test_split
import torch


def load_and_preprocess_data(file_path):
    # 加载数据
    data = pd.read_csv(file_path)

    # 打印列名以检查实际列名
    print("Columns in CSV file:", data.columns)

    # 创建实体和关系的索引映射
    entities = pd.concat([data['头实体'], data['尾实体']]).unique()
    relations = data['关系'].unique()
    entity_to_id = {entity: idx for idx, entity in enumerate(entities)}
    relation_to_id = {relation: idx for idx, relation in enumerate(relations)}

    # 将实体和关系映射为索引
    data['头实体'] = data['头实体'].map(entity_to_id)
    data['关系'] = data['关系'].map(relation_to_id)
    data['尾实体'] = data['尾实体'].map(entity_to_id)

    # 划分训练集、验证集和测试集
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, valid_data = train_test_split(train_data, test_size=0.2, random_state=42)

    # 转换为PyTorch张量
    train_triples = torch.LongTensor(train_data.values)
    valid_triples = torch.LongTensor(valid_data.values)
    test_triples = torch.LongTensor(test_data.values)
      .......

代码解释:加载和预处理数据:读取CSV文件中的三元组数据,并将实体和关系映射为唯一的整数索引。然后,将数据集划分为训练集、验证集和测试集,并转换为PyTorch张量。

model_distmult.py

import torch
import torch.nn as nn

class DistMult(nn.Module):
    def __init__(self, num_entities, num_relations, embedding_dim):
        super(DistMult, self).__init__()
        self.embedding_dim = embedding_dim
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        self.entity_embeddings.weight.data.uniform_(-6 / embedding_dim**0.5, 6 / embedding_dim**0.5)
        self.relation_embeddings.weight.data.uniform_(-6 / embedding_dim**0.5, 6 / embedding_dim**0.5)

    def forward(self, head, relation, tail):
        h = self.entity_embeddings(head)
        r = self.relation_embeddings(relation)
        t = self.entity_embeddings(tail)
        score = torch.sum(h * r * t, dim=1)
        return score
        .......

代码解释:定义DistMult模型:该模块定义了DistMult模型的结构,定义了如何将实体和关系嵌入到低维向量空间,并计算三元组的得分。

train_distmult.py

import torch
import torch.optim as optim
import numpy as np
from model_distmult import DistMult  # 请确保 model_distmult.py 文件在同一目录下
from data_preprocessing import load_and_preprocess_data  # 请确保 data_preprocessing.py 文件在同一目录下

# 加载数据
file_path = '/..../3_三元组数据集.csv'  # 替换为你的文件路径
train_triples, valid_triples, test_triples, entity_to_id, relation_to_id = load_and_preprocess_data(file_path)

# 初始化模型参数
num_entities = len(entity_to_id)
num_relations = len(relation_to_id)
embedding_dim = 100

# 创建模型和优化器
model = DistMult(num_entities, num_relations, embedding_dim)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()

    # 随机选择负样本
    neg_head = torch.randint(0, num_entities, (train_triples.size(0),))
    neg_tail = torch.randint(0, num_entities, (train_triples.size(0),))
    neg_triples = torch.stack((neg_head, train_triples[:, 1], neg_tail), dim=1)

    # 计算正样本和负样本的分数
    pos_score = model(train_triples[:, 0], train_triples[:, 1], train_triples[:, 2])
    neg_score = model(neg_triples[:, 0], neg_triples[:, 1], neg_triples[:, 2])

    # 计算损失并更新模型参数
    loss = model.loss(pos_score, neg_score)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# 保存模型
torch.save(model.state_dict(), 'distmult_model.pth')

# 评估模型
def evaluate_model(model, data, k_values=[1, 3, 10]):
    mean_rank = 0
    hits_at_k = {k: 0 for k in k_values}
    total_triples = data.shape[0]

    for i in range(total_triples):
        head, relation, tail = data[i]
        head_idx = torch.LongTensor([head])
        relation_idx = torch.LongTensor([relation])
        tail_idx = torch.LongTensor([tail])

        all_entities = torch.LongTensor(range(num_entities))

        # 计算所有可能的尾实体分数
        scores_tail = model(head_idx, relation_idx, all_entities).detach().cpu().numpy()
        rank_tail = np.argsort(np.argsort(scores_tail))[tail]

        # 计算所有可能的头实体分数
        scores_head = model(all_entities, relation_idx, tail_idx).detach().cpu().numpy()
        rank_head = np.argsort(np.argsort(scores_head))[head]

        # 更新Mean Rank
        mean_rank += (rank_tail + rank_head) / 2

        # 更新Hits@N
        for k in k_values:
            if rank_tail < k:
                hits_at_k[k] += 1
            if rank_head < k:
                hits_at_k[k] += 1

    mean_rank /= total_triples
    hits_at_k = {k: v / total_triples for k, v in hits_at_k.items()}

    return mean_rank, hits_at_k
        .......

代码解释:训练distmult模型:读取数据,初始化模型和优化器,进行模型训练,并保存训练好的模型。同时定义了评估函数,用于在测试集上评估模型性能。
运行结果如图:
请添加图片描述

completion_distmult.py

import torch
from model_distmult import DistMult
from data_preprocessing import load_and_preprocess_data

# 加载数据和模型
file_path = '/home.../3_三元组数据集.csv'  # 替换为你的文件路径
train_triples, valid_triples, test_triples, entity_to_id, relation_to_id = load_and_preprocess_data(file_path)

# 初始化模型参数
num_entities = len(entity_to_id)
num_relations = len(relation_to_id)
embedding_dim = 100

# 创建和加载模型
model = DistMult(num_entities, num_relations, embedding_dim)
model.load_state_dict(torch.load('distmult_model.pth'))  # 使用实际的模型路径
model.eval()

# 知识补全函数
def predict_tail(head, relation, k=5):
    head_idx = torch.LongTensor([entity_to_id[head]])
    relation_idx = torch.LongTensor([relation_to_id[relation]])
    all_entities = torch.LongTensor(range(num_entities))

    scores = model(head_idx, relation_idx, all_entities)
    _, topk_indices = torch.topk(scores, k, largest=False)

    return [list(entity_to_id.keys())[idx.item()] for idx in topk_indices]

def predict_head(tail, relation, k=5):
    tail_idx = torch.LongTensor([entity_to_id[tail]])
    relation_idx = torch.LongTensor([relation_to_id[relation]])
    all_entities = torch.LongTensor(range(num_entities))
    .......

代码解释:知识补全和预测:加载预训练模型,并使用模型进行知识补全,预测给定头实体和关系的尾实体,或给定尾实体和关系的头实体。
运行结果如图:(数据保密,故作打码处理)
请添加图片描述


总结

DistMult模型作为知识图谱补全的经典方法,通过对称的双线性模型来嵌入实体和关系,尽管假设关系是对称的,但在实际应用中表现出色。其计算简单,适合大规模知识图谱的处理,广泛应用于搜索引擎、推荐系统等领域。然而,其对反对称关系的处理能力有限,需要与其他模型结合使用以提升效果。总体而言,DistMult模型凭借其简洁高效的设计,在知识图谱补全领域具有重要地位。

代码购买链接:https://mbd.pub/o/bread/ZpaUlpxq。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值