python实现排队论模型_python实现TransE模型

本文详细介绍了使用Python实现TransE模型的过程,包括数据处理、模型构建、训练模型等步骤。通过处理FB15K数据集,构建知识图谱的三元组并转换为one-hot编码,然后利用TransE算法进行训练,最终评估模型性能。
摘要由CSDN通过智能技术生成

一般来说,要想训练一个模型,大致分为四个步骤:数据处理、模型构建、训练模型、测试模型,接下来的内容也会从这四个步骤开始介绍。

数据处理

FB15K数据是从Freebase(http://www.freebase.com)抽取到的一系列三元组(同义词集,关系类型,三元组)。该数据集可以看作是3模张量,描述了同义集之间的三元关系。

总共包含了3种文件:

  • train.txt 36M

  • valid.txt 3.7K

  • test.txt 4.4M

首先我定义了一个config.py文件用来存放一些与数据集相关的路径。

class Config(object):    def __init__(self):        super()        self.train_fb15k = "./datasets/fb15k/train.txt"     # 训练集路径        self.test_fb15k = "./datasets/fb15k/test.txt"       # 测试集路径        self.valid_fb15k = "./datasets/fb15k/valid.txt"     # 验证集路径        self.entity2id_train_file = "./datasets/fb15k/entity2id_train.txt"      # 训练集实体到索引的映射        self.relation2id_train_file = "./datasets/fb15k/relation2id_train.txt"  # 训练集关系到索引的映射        self.entity2id_test_file = "./datasets/fb15k/entity2id_test.txt"        # 测试集实体到索引的映射        self.relation2id_test_file = "./datasets/fb15k/relation2id_test.txt"    # 测试集关系到索引的映射        self.entity2id_valid_file = "./datasets/fb15k/entity2id_valid.txt"      # 验证集实体到索引的映射        self.relation2id_valid_file = "./datasets/fb15k/relation2id_valid.txt"  # 验证集关系到索引的映射        self.entity_50dim_batch400 = "./datasets/fb15k/entity_50dim_batch400"   # 400 batch, 实体embedding向量50维的训练结果        self.relation_50dim_batch400 = "./datasets/fb15k/relation_50dim_batch400"   # 400 batch, 关系embedding向量50维的训练结果

然后定义了一个data_process.py文件来对数据集进行处理,该文件包含一个Datasets类用来处理数据。

import osfrom config import Configclass Datasets(object):    def __init__(self, config):        super()        self.config = config        self.entity2id = {}        self.relation2id = {}    def load_data(self, file_path):        '''        加载数据        :param file_path: 数据的文件路径        :return: 读取的数据, 按行划分构成list        '''        with open(file_path, "r", encoding="utf-8") as f:            lines = f.readlines()        return lines    def build_data2id(self, is_test=False):        '''        将数据从字符串转换为index索引, 并保存到对应的路径        :param is_test: 是否是测试集        :return: null        '''        # load data        lines = []        if not is_test:            lines = self.load_data(self.config.train_fb15k)            print("load train data completely.")        else:            lines = self.load_data(self.config.test_fb15k)            print("load test data completely.")        # process 1 line        idx_e = 0        idx_r = 0        for line in lines:            line = line.strip().split("\t")            self.entity2id.setdefault(line[0], idx_e)            idx_e += 1            self.entity2id.setdefault(line[2], idx_e)            idx_e += 1            self.relation2id.setdefault(line[1], idx_r)            idx_r += 1        # save entity2id        if not os.path.exists(self.config.entity2id_train_file):            with open(self.config.entity2id_train_file, "a+", encoding="utf-8") as f:                for k, v in self.entity2id.items():                    entry = k + " " + str(v) + "\n"                    f.write(entry)        # save relation2id        if not os.path.exists(self.config.relation2id_train_file):            with open(self.config.relation2id_train_file, "a+", encoding="utf-8") as f:                for k, v in self.relation2id.items():                    entry = k + " " + str(v) + "\n"                    f.write(entry)    def build_data(self):        '''        将字符型数据转换为由one-hot编码表示的数据        :return:            entity_set: 实体集            relation_set: 关系集            triple_list: 三元组列表        '''        # save entities        entity_set = set()        # save relations        relation_set = set()        # save triples        triple_list = []        # load data        lines = self.load_data(self.config.train_fb15k)        # build data        for line in lines:            triple = line.strip().split("\t")            # h, r, t of a triple            h_ = self.entity2id[triple[0]]            r_ = self.relation2id[triple[1]]            t_ = self.entity2id[triple[2]]            entity_set.add(h_)            entity_set.add(t_)            relation_set.add(r_)            triple_list.append([h_, r_, t_])        return entity_set, relation_set, triple_list

然后可以在该脚本上看看数据的处理过程。

原数据可以通过load_data方法读取,并将其输出

config = Config()datasets = Datasets(config)lines = datasets.load_data(config.train_fb15k)print(lines[:1])

最后在控制台会得到这样的结果

fde8ee79ebc66ee5ef03b5e043c250da.png

上图中的内容实际上是一个三元组,其格式为(头实体,关系,尾实体),只不过实体和关系之间用/t隔开,因此,标准的格式应该是(/m/027rn,/location/country/form_of_government,/m/06cx9),这是知识图谱的RDF表示形式。知道了读取的数据格式,就可以对其进行处理,首先,计算机是很难处理字符型数据的,此外字符类型数据作为存储的话会占用比较大的内存空间,因此需要将这种字符类型的数据转换为one-hot编码,这里采用索引的方式,通过调用build_data2id方法来完成。

datasets.build_data2id()print("entity to index:")print(list(datasets.entity2id.items())[:1])print("relation to index:")print(list(datasets.relation2id.items())[:1])

上面是输出了构建好的实体到索引映射的第一条和关系到索引映射的第一条,然后在控制台就可以看到输出的结果:

0bea70cc3ed96c689385afb858d94b0c.png

前面说的,计算机不好直接处理字符类型,所以我们需要把输入模型的数据也转换为这种索引类型的,以下面这条三元组为例:

fde8ee79ebc66ee5ef03b5e043c250da.png

转换后的结果就变为,表示索引为0的实体与索引为1的实体存在着索引0的关系。这种原始数据到索引表示的数据的处理是通过build_data方法来实现的。

entity_set, relation_set, triple_list = datasets.build_data()

这样子就构建好了模型需要的数据格式。

模型构建

这部分

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值