7.27 TransE_pytoch.py(1)

1.dataloader函数 

def dataloader(file1, file2, file3, file4):
    print("load file...")

    entity = []
    relation = []
    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            #去除头尾空格 以‘\t’作为分隔符
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            #line[0]  entity
            #line[1]  id
            #entities2id:{entity : id}
            entities2id[line[0]] = line[1]
            entity.append(int(line[1]))

        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            #line[0]  relation
            #line[1]  id
            #relations2id:{relation : id}
            relations2id[line[0]] = line[1]
            relation.append(int(line[1]))


    triple_list = []
    relation_head = {}
    relation_tail = {}

    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            # h_:头实体id
            # r_:关系id
            # t_:尾实体id

            h_ = int(entities2id[triple[0]])
            r_ = int(relations2id[triple[1]])
            t_ = int(entities2id[triple[2]])


            triple_list.append([h_, r_, t_])
            # relation_head:{relation: {head_entity:num} }
            # 关系对应的头实体及其数量(即该头实体对应尾实体的个数)
            if r_ in relation_head:
                if h_ in relation_head[r_]:
                    relation_head[r_][h_] += 1
                else:
                    relation_head[r_][h_] = 1
            else:
                relation_head[r_] = {}
                relation_head[r_][h_] = 1

            # relation_tail:{relation: {tail_entity:num} }
            # 关系对应的尾实体及其数量(即该尾实体对应头实体的个数)
            if r_ in relation_tail:
                if t_ in relation_tail[r_]:
                    relation_tail[r_][t_] += 1
                else:
                    relation_tail[r_][t_] = 1
            else:
                relation_tail[r_] = {}
                relation_tail[r_][t_] = 1

    for r_ in relation_head:
        sum1, sum2 = 0, 0
        # sum1 计算有几种头实体
        # sum2 计算头实体数对应的尾实体数
        for head in relation_head[r_]:
            sum1 += 1
            sum2 += relation_head[r_][head]
        # tph 该关系中,平均每个头实体对应的尾实体数
        tph = sum2 / sum1
        relation_tph[r_] = tph

    for r_ in relation_tail:
        sum1, sum2 = 0, 0
        # sum1 计算有几种尾实体
        # sum2 计算尾实体数对应的头实体数
        for tail in relation_tail[r_]:
            sum1 += 1
            sum2 += relation_tail[r_][tail]
        hpt = sum2 / sum1
        relation_hpt[r_] = hpt

    valid_triple_list = []
    with codecs.open(file4, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            h_ = int(entities2id[triple[0]])
            r_ = int(relations2id[triple[1]])
            t_ = int(entities2id[triple[2]])


            valid_triple_list.append([h_, r_, t_])

    print("Complete load. entity : %d , relation : %d , train triple : %d, , valid triple : %d" % (
    len(entity), len(relation), len(triple_list), len(valid_triple_list)))

    return entity, relation, triple_list, valid_triple_list

计算头实体(尾实体)对应的尾实体(头实体)数量tph(hpt),在构建负样本时会用到。

例如,在一个知识图谱中有10个实体和n个关系,其中一个关系2个头实体对应5个尾实体,则tph=2.5,hpt=0.4。替换头实体的负样本正确概率为,替换尾实体的负样本正确概率为。因此构建正确负样本时就要替换头实体。

2.__init__函数

其参数有:

  • entity_num:entity 的数量

  • relation_num:relation 的数量

  • dim:每个 embedding vector(嵌入向量)的维度

  • norm:在计算d(h+l , t)时是使用L1范数还是L2范数

  • margin:损失函数中的间隔,正负样本三元组之间的间隔修正

  • C:损失函数计算中的正则化项参数

class E(nn.Module):
    def __init__(self, entity_num, relation_num, dim, margin, norm, C):
        super(E, self).__init__()
        self.entity_num = entity_num
        self.relation_num = relation_num
        self.dim = dim
        self.margin = margin
        self.norm = norm
        self.C = C

        #初始化实体和关系表示向量
        self.ent_embedding = torch.nn.Embedding(num_embeddings=self.entity_num,
                                                          embedding_dim=self.dim).cuda()
        self.rel_embedding = torch.nn.Embedding(num_embeddings=self.relation_num,
                                                           embedding_dim=self.dim).cuda()
        
        #损失函数
        self.loss_F = nn.MarginRankingLoss(self.margin, reduction="mean").cuda()

        self.__data_init()

初始化 embedding matrix(嵌入矩阵)时,直接用 torch.nn.Embedding 来完成,参数分别是 entity 的数量和每个 embedding vector 的维数,这样得到的就是一个 entity_num * dim 大小的 Embedding Matrix。

 参考:

TransE模型-数据预处理

【KG】TransE及其实现

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值