# TransE测试代码
#
# 本文介绍如何使用Python实现基于TransE模型的数据加载、距离计算和测试过程，包括实体嵌入、关系嵌入的处理，以及对测试数据进行hits@10和mean rank的评估。
# （摘要由CSDN通过智能技术生成）
import numpy as np
import codecs
import operator
import json
from transE import data_loader,entity2id,relation2id

def _read_embeddings(path):
    """Parse a ``name<TAB>json-array`` file into a ``{name: embedding}`` dict."""
    table = {}
    # No encoding is passed, so the platform default applies -- the same
    # behaviour codecs.open() has when it is called without an encoding.
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at end of file) carry
                # no record; skipping them avoids a ValueError on unpacking.
                continue
            name, embedding = line.split('\t')
            table[name] = json.loads(embedding)
    return table


def dataloader(entity_file, relation_file, test_file):
    """Load trained embeddings and the triples to evaluate them on.

    Args:
        entity_file: path to a text file with one ``entity \\t embedding``
            record per line, the embedding being a JSON array.
        relation_file: path to a file in the same format for relations.
        test_file: path to a file with three tab-separated fields per line.
            The fields are read as (head, tail, relation) and mapped through
            ``entity2id`` / ``relation2id``; lines that do not split into
            exactly three fields are skipped.

    Returns:
        A tuple ``(entity_dict, relation_dict, test_triple)`` where the two
        dicts map a name to its decoded embedding and ``test_triple`` is a
        list of ``(head_id, tail_id, relation_id)`` tuples.

    Raises:
        ValueError: if a non-blank embedding line does not consist of exactly
            two tab-separated fields, or its embedding is not valid JSON.
        KeyError: if a test triple names an entity or relation that is missing
            from ``entity2id`` / ``relation2id``.
    """
    entity_dict = _read_embeddings(entity_file)
    relation_dict = _read_embeddings(relation_file)

    test_triple = []
    with open(test_file) as t_f:
        for line in t_f:
            fields = line.strip().split('\t')
            if len(fields) != 3:
                continue
            head, tail, relation = fields
            test_triple.append(
                (entity2id[head], entity2id[tail], relation2id[relation]))

    return entity_dict, relation_dict, test_triple

def distance(h,r,t):
    """Return the norm of ``h + r - t`` for three embeddings.

    Each argument is converted with ``np.array`` and the result is whatever
    ``np.linalg.norm`` yields with its default arguments (the Euclidean
    length when the inputs are 1-D vectors). A smaller value means the
    translated head ``h + r`` lies closer to the tail ``t``.
    """
    head_vec, relation_vec, tail_vec = (np.array(vec) for vec in (h, r, t))
    residual = head_vec + relation_vec - tail_vec
    return np.linalg.norm(residual)

class Test:
    """Evaluate embeddings by ranking the true triple among corrupted ones.

    Triples are indexed as ``(head, tail, relation)``: positions 0 and 1 are
    looked up in ``entity_dict`` and position 2 in ``relation_dict``.

    Results are stored on the instance:
        hits10 / mean_rank: set by ``rank()`` (head and tail prediction).
        relation_hits10 / relation_mean_rank: set by ``relation_rank()``.
    """

    def __init__(self,entity_dict,relation_dict,test_triple,train_triple,isFit = True):
        """Store the embeddings and triples to evaluate.

        Args:
            entity_dict: mapping entity -> embedding.
            relation_dict: mapping relation -> embedding.
            test_triple: sequence of ``(head, tail, relation)`` triples to rank.
            train_triple: iterable of known triples (lists or tuples); only
                consulted when ``isFit`` is true.
            isFit: when true, corrupted candidates that appear in
                ``train_triple`` are excluded from the ranking; when false,
                every candidate is ranked.
        """
        self.entity_dict = entity_dict
        self.relation_dict = relation_dict
        self.test_triple = test_triple
        self.train_triple = train_triple
        self.isFit = isFit

        self.hits10 = 0
        self.mean_rank = 0

        self.relation_hits10 = 0
        self.relation_mean_rank = 0

    def _known_triples(self):
        """Return the triples to filter out, as a set of tuples.

        Normalising every triple to a tuple makes the membership test work
        whether ``train_triple`` holds lists or tuples, and turns each lookup
        into a hash probe instead of a scan over the whole training list.
        """
        if not self.isFit or self.train_triple is None:
            return frozenset()
        return frozenset(tuple(known) for known in self.train_triple)

    def _score(self, head, tail, relation):
        """Return ``distance`` for the embeddings of one (head, tail, relation)."""
        return distance(self.entity_dict[head],
                        self.relation_dict[relation],
                        self.entity_dict[tail])

    @staticmethod
    def _rank_of(scores, target):
        """Return the 1-based position of ``target`` by ascending score.

        Returns None when ``target`` has no score. ``sorted`` is stable, so
        equal scores keep the insertion order of ``scores``.
        """
        if target not in scores:
            return None
        return sorted(scores, key=scores.get).index(target) + 1

    def rank(self):
        """Rank the true head and the true tail of every test triple.

        For each test triple, every entity is tried as a replacement head and
        as a replacement tail. The true triple is always kept as a candidate,
        even in filtered mode, so it always receives a rank. Sets
        ``self.hits10`` (fraction of ranks <= 10) and ``self.mean_rank``,
        both averaged over the two predictions per triple; both are 0.0 when
        there are no test triples.
        """
        known = self._known_triples()
        hits = 0
        rank_sum = 0

        for step, triple in enumerate(self.test_triple, start=1):
            head, tail, relation = triple[0], triple[1], triple[2]
            head_scores = {}
            tail_scores = {}

            for entity in self.entity_dict:
                if entity == head or (entity, tail, relation) not in known:
                    head_scores[entity] = self._score(entity, tail, relation)
                if entity == tail or (head, entity, relation) not in known:
                    tail_scores[entity] = self._score(head, entity, relation)

            positions = (self._rank_of(head_scores, head),
                         self._rank_of(tail_scores, tail))
            for position in positions:
                if position is None:
                    # Only possible when the true entity has no embedding
                    # candidates at all (e.g. empty entity_dict).
                    continue
                if position <= 10:
                    hits += 1
                rank_sum += position

            # step counts the triples processed so far.
            if step % 5000 == 0:
                print("step ", step, " ,hits ",hits," ,rank_sum ",rank_sum)
                print()

        total = 2 * len(self.test_triple)
        self.hits10 = hits / total if total else 0.0
        self.mean_rank = rank_sum / total if total else 0.0

    def relation_rank(self):
        """Rank the true relation of every test triple.

        Every relation is tried in place of the true one; the true triple is
        always kept as a candidate. Ranks are 1-based, a hit is a rank <= 10,
        and a true relation without any score is given the worst rank
        (number of scored candidates + 1). Sets ``self.relation_hits10`` and
        ``self.relation_mean_rank``; both are 0.0 when there are no test
        triples.
        """
        known = self._known_triples()
        hits = 0
        rank_sum = 0

        for step, triple in enumerate(self.test_triple, start=1):
            head, tail, relation = triple[0], triple[1], triple[2]
            scores = {}
            for candidate in self.relation_dict:
                if candidate != relation and (head, tail, candidate) in known:
                    continue
                scores[candidate] = self._score(head, tail, candidate)

            position = self._rank_of(scores, relation)
            if position is None:
                position = len(scores) + 1
            if position <= 10:
                hits += 1
            rank_sum += position

            if step % 5000 == 0:
                print("relation step ", step, " ,hits ", hits, " ,rank_sum ", rank_sum)
                print()

        total = len(self.test_triple)
        self.relation_hits10 = hits / total if total else 0.0
        self.relation_mean_rank = rank_sum / total if total else 0.0

if __name__ == '__main__':
    # Script entry: evaluate stored embeddings on the test split, print the
    # four metrics and also save them to result.txt.
    # NOTE: the dataset paths below hard-code a backslash separator.

    # data_loader returns three values; only the third (training triples)
    # is used here.
    _, _, train_triple = data_loader("FB15k\\")

    entity_dict, relation_dict, test_triple = dataloader(
        "entity_50dim_batch400",
        "relation50dim_batch400",
        "FB15k\\test.txt",
    )

    # isFit=False ranks against every candidate (no filtering by train_triple).
    test = Test(entity_dict,relation_dict,test_triple,train_triple,isFit=False)

    # Entity results are printed before the relation pass starts, so they
    # are visible without waiting for the second evaluation.
    test.rank()
    print("entity hits@10: ", test.hits10)
    print("entity meanrank: ", test.mean_rank)

    test.relation_rank()
    print("relation hits@10: ", test.relation_hits10)
    print("relation meanrank: ", test.relation_mean_rank)

    # The context manager closes the file even if one of the writes fails.
    with open("result.txt",'w') as f:
        f.write("entity hits@10: "+ str(test.hits10) + '\n')
        f.write("entity meanrank: " + str(test.mean_rank) + '\n')
        f.write("relation hits@10: " + str(test.relation_hits10) + '\n')
        f.write("relation meanrank: " + str(test.relation_mean_rank) + '\n')

# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值