一般来说,要想训练一个模型,大致分为四个步骤:数据处理、模型构建、训练模型、测试模型,接下来的内容也会从这四个步骤开始介绍。
数据处理
FB15K数据是从Freebase(http://www.freebase.com)抽取到的一系列三元组(同义词集,关系类型,三元组)。该数据集可以看作是3模张量,描述了同义集之间的三元关系。
总共包含了3种文件:
train.txt 36M
valid.txt 3.7K
test.txt 4.4M
首先我定义了一个config.py文件用来存放一些与数据集相关的路径。
class Config(object): def __init__(self): super() self.train_fb15k = "./datasets/fb15k/train.txt" # 训练集路径 self.test_fb15k = "./datasets/fb15k/test.txt" # 测试集路径 self.valid_fb15k = "./datasets/fb15k/valid.txt" # 验证集路径 self.entity2id_train_file = "./datasets/fb15k/entity2id_train.txt" # 训练集实体到索引的映射 self.relation2id_train_file = "./datasets/fb15k/relation2id_train.txt" # 训练集关系到索引的映射 self.entity2id_test_file = "./datasets/fb15k/entity2id_test.txt" # 测试集实体到索引的映射 self.relation2id_test_file = "./datasets/fb15k/relation2id_test.txt" # 测试集关系到索引的映射 self.entity2id_valid_file = "./datasets/fb15k/entity2id_valid.txt" # 验证集实体到索引的映射 self.relation2id_valid_file = "./datasets/fb15k/relation2id_valid.txt" # 验证集关系到索引的映射 self.entity_50dim_batch400 = "./datasets/fb15k/entity_50dim_batch400" # 400 batch, 实体embedding向量50维的训练结果 self.relation_50dim_batch400 = "./datasets/fb15k/relation_50dim_batch400" # 400 batch, 关系embedding向量50维的训练结果
然后定义了一个data_process.py文件来对数据集进行处理,该文件包含一个Datasets类用来处理数据。
import osfrom config import Configclass Datasets(object): def __init__(self, config): super() self.config = config self.entity2id = {} self.relation2id = {} def load_data(self, file_path): ''' 加载数据 :param file_path: 数据的文件路径 :return: 读取的数据, 按行划分构成list ''' with open(file_path, "r", encoding="utf-8") as f: lines = f.readlines() return lines def build_data2id(self, is_test=False): ''' 将数据从字符串转换为index索引, 并保存到对应的路径 :param is_test: 是否是测试集 :return: null ''' # load data lines = [] if not is_test: lines = self.load_data(self.config.train_fb15k) print("load train data completely.") else: lines = self.load_data(self.config.test_fb15k) print("load test data completely.") # process 1 line idx_e = 0 idx_r = 0 for line in lines: line = line.strip().split("\t") self.entity2id.setdefault(line[0], idx_e) idx_e += 1 self.entity2id.setdefault(line[2], idx_e) idx_e += 1 self.relation2id.setdefault(line[1], idx_r) idx_r += 1 # save entity2id if not os.path.exists(self.config.entity2id_train_file): with open(self.config.entity2id_train_file, "a+", encoding="utf-8") as f: for k, v in self.entity2id.items(): entry = k + " " + str(v) + "\n" f.write(entry) # save relation2id if not os.path.exists(self.config.relation2id_train_file): with open(self.config.relation2id_train_file, "a+", encoding="utf-8") as f: for k, v in self.relation2id.items(): entry = k + " " + str(v) + "\n" f.write(entry) def build_data(self): ''' 将字符型数据转换为由one-hot编码表示的数据 :return: entity_set: 实体集 relation_set: 关系集 triple_list: 三元组列表 ''' # save entities entity_set = set() # save relations relation_set = set() # save triples triple_list = [] # load data lines = self.load_data(self.config.train_fb15k) # build data for line in lines: triple = line.strip().split("\t") # h, r, t of a triple h_ = self.entity2id[triple[0]] r_ = self.relation2id[triple[1]] t_ = self.entity2id[triple[2]] entity_set.add(h_) entity_set.add(t_) relation_set.add(r_) triple_list.append([h_, r_, t_]) return entity_set, relation_set, triple_list
然后可以在该脚本上看看数据的处理过程。
原数据可以通过load_data方法读取,并将其输出
config = Config()datasets = Datasets(config)lines = datasets.load_data(config.train_fb15k)print(lines[:1])
最后在控制台会得到这样的结果
上图中的内容实际上是一个三元组,其格式为(头实体,关系,尾实体),只不过实体和关系之间用/t隔开,因此,标准的格式应该是(/m/027rn,/location/country/form_of_government,/m/06cx9),这是知识图谱的RDF表示形式。知道了读取的数据格式,就可以对其进行处理,首先,计算机是很难处理字符型数据的,此外字符类型数据作为存储的话会占用比较大的内存空间,因此需要将这种字符类型的数据转换为one-hot编码,这里采用索引的方式,通过调用build_data2id方法来完成。
datasets.build_data2id()print("entity to index:")print(list(datasets.entity2id.items())[:1])print("relation to index:")print(list(datasets.relation2id.items())[:1])
上面是输出了构建好的实体到索引映射的第一条和关系到索引映射的第一条,然后在控制台就可以看到输出的结果:
前面说的,计算机不好直接处理字符类型,所以我们需要把输入模型的数据也转换为这种索引类型的,以下面这条三元组为例:
转换后的结果就变为,表示索引为0的实体与索引为1的实体存在着索引0的关系。这种原始数据到索引表示的数据的处理是通过build_data方法来实现的。
entity_set, relation_set, triple_list = datasets.build_data()
这样子就构建好了模型需要的数据格式。
模型构建
这部分