源代码
源代码参考项目
数据说明
数据集采用FB15K,下面代码中的文件分别为(各文件中每行的字段之间均以制表符 `\t` 分隔,下方示例中显示为空格):
file1:训练集,格式为(head,relation,tail)
例:
/m/027rn /location/country/form_of_government /m/06cx9
/m/017dcd /tv/tv_program/regular_cast./tv/regular_tv_appearance/actor /m/06v8s0
/m/07s9rl0 /media_common/netflix_genre/titles /m/0170z3
file2:entity2id.txt,格式为(entity,id)
例:
/m/06rf7 0
/m/0c94fn 1
/m/016ywr 2
file3:relation2id.txt,格式为(relation,id)
例:
/people/appointed_role/appointment./people/appointment/appointed_by 0
/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency 1
/tv/tv_series_episode/guest_stars./tv/tv_guest_role/actor 2
file4:验证集,格式为(head,relation,tail)
例:
/m/07pd_j /film/film/genre /m/02l7c8
/m/06wxw /location/location/time_zones /m/02fqwt
/m/0d4fqn /award/award_winner/awards_won./award/award_honor/award_winner /m/03wh8kl
代码解释
def _read_id_map(path, mapping):
    """Read a tab-separated ``name<TAB>id`` file into ``mapping``; return the ids as ints.

    Lines that do not split into exactly two tab-separated fields (for
    example blank lines, or a leading count line if the file has one) are
    skipped. ``mapping`` receives ``name -> id`` with the id kept as the raw
    string read from the file; the returned list holds the same ids converted
    to ``int``, in file order.
    """
    ids = []
    with open(path, 'r', encoding='utf-8') as f:
        for raw in f:
            fields = raw.strip().split('\t')
            if len(fields) != 2:
                continue
            name, idx = fields
            mapping[name] = idx
            ids.append(int(idx))
    return ids


def _read_triples(path, ent2id, rel2id):
    """Return ``[head_id, relation_id, tail_id]`` int lists for a tab-separated triple file.

    Each line is expected to be ``head<TAB>relation<TAB>tail`` (names, not
    ids); names are resolved through ``ent2id`` / ``rel2id``. Lines without
    exactly three fields are skipped. A name missing from the maps raises
    ``KeyError``.
    """
    triples = []
    with open(path, 'r', encoding='utf-8') as f:
        for raw in f:
            fields = raw.strip().split('\t')
            if len(fields) != 3:
                continue
            head, rel, tail = fields
            triples.append([int(ent2id[head]), int(rel2id[rel]), int(ent2id[tail])])
    return triples


def _mean_count(counts):
    """Return the average count per distinct key of a non-empty ``Counter``."""
    return sum(counts.values()) / len(counts)


def dataloader(file1, file2, file3, file4):
    """Load the FB15K id maps, training triples and validation triples.

    Args:
        file1: training triples, one tab-separated ``head relation tail`` per line.
        file2: entity id file (entity2id.txt), one tab-separated ``entity id`` per line.
        file3: relation id file (relation2id.txt), one tab-separated ``relation id`` per line.
        file4: validation triples, same format as ``file1``.

    Side effects -- fills mappings that live at module level (defined outside
    this function):
        entities2id / relations2id: ``name -> id``, the id stored as the raw
            string read from the file.
        relation_tph: ``relation_id -> `` average number of tails per distinct
            head of that relation in the training set.
        relation_hpt: ``relation_id -> `` average number of heads per distinct
            tail of that relation in the training set.
        tph/hpt are used later when constructing negative samples.

    Returns:
        (entity, relation, triple_list, valid_triple_list) where ``entity`` and
        ``relation`` are the int ids in file order, and the triple lists hold
        ``[head_id, relation_id, tail_id]`` ints.

    Raises:
        KeyError: a triple names an entity or relation absent from the id files.
        ValueError: an id field in an id file is not an integer.
    """
    from collections import Counter, defaultdict

    print("load file...")
    entity = _read_id_map(file2, entities2id)
    relation = _read_id_map(file3, relations2id)

    triple_list = _read_triples(file1, entities2id, relations2id)

    # relation -> {head_id: number of training triples with that head}
    relation_head = defaultdict(Counter)
    # relation -> {tail_id: number of training triples with that tail}
    relation_tail = defaultdict(Counter)
    for h_, r_, t_ in triple_list:
        relation_head[r_][h_] += 1
        relation_tail[r_][t_] += 1

    # Each Counter has at least one entry (it was created by an increment),
    # so the averages never divide by zero.
    for r_, head_counts in relation_head.items():
        relation_tph[r_] = _mean_count(head_counts)
    for r_, tail_counts in relation_tail.items():
        relation_hpt[r_] = _mean_count(tail_counts)

    valid_triple_list = _read_triples(file4, entities2id, relations2id)

    print("Complete load. entity : %d , relation : %d , train triple : %d, valid triple : %d" % (
        len(entity), len(relation), len(triple_list), len(valid_triple_list)))
    return entity, relation, triple_list, valid_triple_list
上面的代码还为每个关系计算了 tph(该关系中平均每个头实体对应的尾实体数)和 hpt(该关系中平均每个尾实体对应的头实体数),二者在构建负样本时会用到。
例如,在一个知识图谱中有10个实体和n个关系,其中某个关系共有5个三元组,涉及2个头实体和5个尾实体,则 $tph=\frac{5}{2}=2.5$,$hpt=\frac{5}{5}=1$(按代码定义,hpt 为三元组数除以不同尾实体数,因此不会小于1)。对该关系的一个三元组,若从其余9个实体中随机选一个替换头实体,只要选中的不是该关系的另一个头实体,得到的就一定是错误三元组(有效负样本),其概率为 $\frac{10-2}{10-1}=\frac{8}{9}$;若替换尾实体,同理概率为 $\frac{10-5}{10-1}=\frac{5}{9}$。因此对于这类 tph 较大的关系,构建负样本时应优先替换头实体。