TransE
知识图谱基础
三元组(h,r,t)
知识表示
即将实体和关系向量化,embedding
算法描述
思想:一个正确的三元组的embedding会满足:h+r=t
定义距离d表示向量之间的距离,一般取L1或者L2,期望正确的三元组的距离越小越好,而错误的三元组的距离越大越好。为此给出目标函数为:
梯度求解:
代码分析
- 定义类:
参数:
目标函数的常数——margin
学习率——learningRate
向量维度——dim
实体列表——entityList(读取文本文件,实体+id)
关系列表——relationList(读取文本文件,关系 + id)
三元关系列表——tripleList(读取文本文件,实体 + 实体 + 关系)
损失值——loss
距离公式——L1
- 向量初始化
规定初始化维度和取值范围(TransE算法原理中的取值范围)
涉及的函数:
init:随机生成值
norm:归一化
- 训练向量
getSample——随机选取部分三元关系,Sbatch
getCorruptedTriplet(sbatch)——随机替换三元组的实体,h、t中任意一个被替换,但不同时替换。
update——更新
L2更新向量的推导过程:
python 函数
uniform(a, b)#随机生成a,b之间的数,左闭右开。
求向量的模,var = linalg.norm(list)
"""
@version: 3.7
@author: jiayalu
@file: trainTransE.py
@time: 22/08/2019 10:56
@description: 用于对知识图谱中的实体、关系基于TransE算法训练获取向量
数据:三元关系
实体id和关系id
结果为:两个文本文件,即entityVector.txt和relationVector.txt 实体 [array向量]
"""
from random import uniform, sample
from numpy import *
from copy import deepcopy
class TransE:
def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
self.margin = margin
self.learingRate = learingRate
self.dim = dim#向量维度
self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量(使用narray)。
self.relationList = relationList#理由同上
self.tripleList = tripleList#理由同上
self.loss = 0
self.L1 = L1
def initialize(self):
'''
初始化向量
'''
entityVectorList = {
}
relationVectorList = {
}
for entity in self.entityList:
n = 0
entityVector = []
while n < self.dim:
ram = init(self.dim)#初始化的范围
entityVector.append(ram)
n += 1
entityVector = norm(entityVector)#归一化
entityVectorList[entity] = entityVector
print("entityVector初始化完成,数量是%d"%len(entityVectorList))
for relation in self. relationList:
n = 0
relationVector = []
while n < self.dim:
ram = init(self.dim)#初始化的范围
relationVector.append(ram)
n += 1
relationVector = norm(relationVector)#归一化
relationVectorList[relation] = relationVector
print("relationVectorList初始化完成,数量是%d"%len(relationVectorList))
self.entityList = entityVectorList
self.relationList = relationVectorList
def transE(self, cI = 20):
print("训练开始")
for cycleIndex in range(cI):
Sbatch = self.getSample(3)
T