TransE Model: Data Preprocessing

Source code

Reference project for the source code

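Data description

The dataset is FB15K. The files passed to the code below are:

file1: the training set, one tab-separated triple per line in the format (head, relation, tail)

Example: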

/m/027rn	/location/country/form_of_government	/m/06cx9
/m/017dcd	/tv/tv_program/regular_cast./tv/regular_tv_appearance/actor	/m/06v8s0
/m/07s9rl0	/media_common/netflix_genre/titles	/m/0170z3

file2: entity2id.txt, one tab-separated pair per line in the format (entity, id)

Example:

/m/06rf7	0
/m/0c94fn	1
/m/016ywr	2

file3: relation2id.txt, one tab-separated pair per line in the format (relation, id)

Example:

/people/appointed_role/appointment./people/appointment/appointed_by	0
/location/statistical_region/rent50_2./measurement_unit/dated_money_value/currency	1
/tv/tv_series_episode/guest_stars./tv/tv_guest_role/actor	2

file4: the validation set, one tab-separated triple per line in the format (head, relation, tail)

Example:

/m/07pd_j	/film/film/genre	/m/02l7c8
/m/06wxw	/location/location/time_zones	/m/02fqwt
/m/0d4fqn	/award/award_winner/awards_won./award/award_honor/award_winner	/m/03wh8kl
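
The file formats above can be parsed with a few lines of Python. The following is a minimal sketch, not the reference project's code: the helper names `parse_id_lines` and `parse_triple_lines` are hypothetical, and the ids assigned to the sample entities and relation are invented for the illustration.

```python
def parse_id_lines(lines):
    """Parse 'name<TAB>id' lines into a {name: int_id} dict."""
    mapping = {}
    for line in lines:
        parts = line.strip().split("\t")
        if len(parts) != 2:
            continue  # skip blank or malformed lines
        mapping[parts[0]] = int(parts[1])
    return mapping


def parse_triple_lines(lines, entity2id, relation2id):
    """Parse 'head<TAB>relation<TAB>tail' lines into [h_id, r_id, t_id] lists."""
    triples = []
    for line in lines:
        parts = line.strip().split("\t")
        if len(parts) != 3:
            continue  # skip blank or malformed lines
        h, r, t = parts
        triples.append([entity2id[h], relation2id[r], entity2id[t]])
    return triples


# Hypothetical ids, chosen only for this illustration.
entity_lines = ["/m/027rn\t0\n", "/m/06cx9\t1\n"]
relation_lines = ["/location/country/form_of_government\t0\n"]
train_lines = ["/m/027rn\t/location/country/form_of_government\t/m/06cx9\n", "\n"]

entity2id = parse_id_lines(entity_lines)
relation2id = parse_id_lines(relation_lines)
triples = parse_triple_lines(train_lines, entity2id, relation2id)
print(triples)  # [[0, 0, 1]]
```

Converting the ids to `int` while reading the mapping files avoids repeating the conversion for every triple.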

Code walkthrough

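import codecs

# Module-level lookup tables filled in by dataloader()
entities2id = {}
relations2id = {}
relation_tph = {}
relation_hpt = {}


def dataloader(file1, file2, file3, file4):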
    print("load file...")
    entity = []
    relation = []
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            # strip leading/trailing whitespace, then split on '\t'
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            # line[0] entity
            # line[1] id
            # entities2id:{entity : id}
            entities2id[line[0]] = line[1]
            entity.append(int(line[1]))

        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            # line[0] relation
            # line[1] id
            # relations2id:{relation : id}
            relations2id[line[0]] = line[1]
            relation.append(int(line[1]))

    triple_list = []
    relation_head = {}
    relation_tail = {}

    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue
            # h_ : head entity id
            # r_ : relation id
            # t_ : tail entity id
            h_ = int(entities2id[triple[0]])
            r_ = int(relations2id[triple[1]])
            t_ = int(entities2id[triple[2]])

            triple_list.append([h_, r_, t_])
            # relation_head:{relation: {head_entity:num} }
            # head entities seen with this relation and, for each, how many tail entities it is linked to
            if r_ in relation_head:
                if h_ in relation_head[r_]:
                    relation_head[r_][h_] += 1
                else:
                    relation_head[r_][h_] = 1
            else:
                relation_head[r_] = {}
                relation_head[r_][h_] = 1
            # relation_tail:{relation: {tail_entity:num} }
            # tail entities seen with this relation and, for each, how many head entities it is linked to
            if r_ in relation_tail:
                if t_ in relation_tail[r_]:
                    relation_tail[r_][t_] += 1
                else:
                    relation_tail[r_][t_] = 1
            else:
                relation_tail[r_] = {}
                relation_tail[r_][t_] = 1

    for r_ in relation_head:
        sum1, sum2 = 0, 0
        # sum1: number of distinct head entities
        # sum2: total number of tail entities linked to those heads
        for head in relation_head[r_]:
            sum1 += 1
            sum2 += relation_head[r_][head]
        # tph: average number of tail entities per head entity for this relation
        tph = sum2 / sum1
        relation_tph[r_] = tph

    for r_ in relation_tail:
        sum1, sum2 = 0, 0
        # sum1: number of distinct tail entities
        # sum2: total number of head entities linked to those tails
        for tail in relation_tail[r_]:
            sum1 += 1
            sum2 += relation_tail[r_][tail]
        hpt = sum2 / sum1
        relation_hpt[r_] = hpt

    valid_triple_list = []
    with codecs.open(file4, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            h_ = int(entities2id[triple[0]])
            r_ = int(relations2id[triple[1]])
            t_ = int(entities2id[triple[2]])

            valid_triple_list.append([h_, r_, t_])

    print("Complete load. entity : %d , relation : %d , train triple : %d , valid triple : %d" % (
        len(entity), len(relation), len(triple_list), len(valid_triple_list)))

    return entity, relation, triple_list, valid_triple_list

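For each relation, the function also computes tph, the average number of tail entities per head entity, and hpt, the average number of head entities per tail entity. Both statistics are used later when constructing negative samples.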
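
One common way to use these statistics is the Bernoulli sampling scheme introduced with TransH, which replaces the head with probability tph / (tph + hpt) and the tail otherwise. Whether the reference project uses exactly this scheme is an assumption here. The sketch below is a compact, hypothetical re-implementation: `compute_tph_hpt` and `corrupt` are illustrative names, and `corrupt` does not check whether the corrupted triple happens to exist in the training set.

```python
import random
from collections import defaultdict


def compute_tph_hpt(triples):
    """Return ({relation: tph}, {relation: hpt}) for [h, r, t] id triples."""
    head_counts = defaultdict(lambda: defaultdict(int))  # r -> h -> #triples
    tail_counts = defaultdict(lambda: defaultdict(int))  # r -> t -> #triples
    for h, r, t in triples:
        head_counts[r][h] += 1
        tail_counts[r][t] += 1
    # triples per distinct head / per distinct tail, as in dataloader()
    tph = {r: sum(c.values()) / len(c) for r, c in head_counts.items()}
    hpt = {r: sum(c.values()) / len(c) for r, c in tail_counts.items()}
    return tph, hpt


def corrupt(triple, entities, tph, hpt, rng=random):
    """Replace the head with prob tph/(tph+hpt), otherwise the tail."""
    h, r, t = triple
    p_head = tph[r] / (tph[r] + hpt[r])
    if rng.random() < p_head:
        h = rng.choice([e for e in entities if e != h])
    else:
        t = rng.choice([e for e in entities if e != t])
    return [h, r, t]


# Toy data: relation 0 has one head (entity 0) linked to three tails.
toy = [[0, 0, 1], [0, 0, 2], [0, 0, 3]]
tph, hpt = compute_tph_hpt(toy)
print(tph[0], hpt[0])  # 3.0 1.0
negative = corrupt(toy[0], entities=[0, 1, 2, 3, 4], tph=tph, hpt=hpt)
```

With tph = 3.0 and hpt = 1.0 the head is replaced three times out of four on average, which matches the intuition that a one-to-many relation has few valid heads and many valid tails.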

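For example, suppose a knowledge graph has 10 entities and n relations, and one relation connects 2 head entities to 5 tail entities through 5 triples, with each tail entity appearing in exactly one triple. Then tph = 5/2 = 2.5, and hpt = 5/5 = 1 under the definition used in the code (triples per distinct tail entity); the ratio of distinct heads to distinct tails is 2/5 = 0.4. A negative sample built by replacing the head entity is a true negative with probability $\frac{10-2}{10-1}=\frac{8}{9}$, since 8 of the 10 entities are not heads of this relation and 9 candidates remain once the original head is excluded. A negative sample built by replacing the tail entity is a true negative with probability $\frac{10-5}{10-1}=\frac{5}{9}$. To obtain true negatives for this relation, it is therefore better to replace the head entity.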
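
The two probabilities in this example can be checked with exact fractions. This is only a sketch of the estimate used in the text, (N - k) / (N - 1) for N entities of which k are valid in the replaced position; the function name `true_negative_prob` is hypothetical.

```python
from fractions import Fraction


def true_negative_prob(num_entities, num_valid):
    """(N - k) / (N - 1): the estimate used in the example above."""
    return Fraction(num_entities - num_valid, num_entities - 1)


p_head = true_negative_prob(10, 2)  # replace the head: 2 valid heads
p_tail = true_negative_prob(10, 5)  # replace the tail: 5 valid tails
print(p_head, p_tail)  # 8/9 5/9
```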
