OpenKE-TransE Code Reading
TransE pseudocode
Input: training set S = {(h, l, t)}, entity set E, relation set L, margin γ, embedding dimension k
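γ: the margin hyperparameter. The loss for one pair is max(0, γ + d[positive triple] - d[negative triple]). The difference d[positive triple] - d[negative triple] should be negative, and the positive margin γ is added to it, so the expression stays positive while the negative triple is not yet far enough away. As the loss decreases, the difference becomes more and more negative; once its absolute value exceeds the margin, the whole expression turns negative. The loss keeps only the positive part, so a negative value is set to 0 and that pair no longer contributes to training.
The margin is therefore the largest gap between the negative and positive distances that training actively enforces. With a margin, nothing pushes the negative sample's d toward infinity.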
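To make the clamping concrete, here is a minimal sketch of the margin-based loss for a single positive/negative pair. The function name `margin_loss` and the sample distances are made up for illustration and are not taken from OpenKE.

```python
def margin_loss(d_pos: float, d_neg: float, margin: float) -> float:
    """Hinge loss max(0, margin + d_pos - d_neg) for a single pair."""
    return max(0.0, margin + d_pos - d_neg)


margin = 1.0

# The negative triple is only 0.5 farther than the positive one,
# which is less than the margin, so the pair is still penalized.
print(margin_loss(d_pos=0.5, d_neg=1.0, margin=margin))  # 0.5

# The gap (1.5) already exceeds the margin, so the loss is clamped to 0
# and there is no pressure to push d_neg any further.
print(margin_loss(d_pos=0.5, d_neg=2.0, margin=margin))  # 0.0
```

In the second call the raw value is -0.5, which is exactly the case described above where the expression turns negative and is set to 0.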
Input files: relation2id.txt, entity2id.txt, train2id.txt
relation2id.txt:
/location/country/form_of_government 0
/tv/tv_program/regular_cast./tv/regular_tv_appearance/actor 1
entity2id.txt:
/m/027rn 0
/m/06cx9 1
train2id.txt:
0 1 0
2 3 1
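As an illustration of how these files fit together, the sketch below parses a few lines in the format shown above. It assumes that each train2id.txt row is ordered head, tail, relation (the e1 e2 rel layout, as far as I know, of the OpenKE benchmark files) and that the name and the id are separated by whitespace. The helper names `parse_id_map` and `parse_triples` are hypothetical. Lines that do not match the expected shape, such as a leading count line, are skipped.

```python
def parse_id_map(lines):
    """Map each name to its integer id; lines look like '<name> <id>'."""
    mapping = {}
    for line in lines:
        parts = line.rsplit(None, 1)
        if len(parts) != 2:  # e.g. a count line or a blank line
            continue
        name, idx = parts
        mapping[name] = int(idx)
    return mapping


def parse_triples(lines):
    """Return (head, tail, relation) id tuples; ASSUMES 'h t r' column order."""
    triples = []
    for line in lines:
        parts = line.split()
        if len(parts) != 3:  # skip anything that is not a triple
            continue
        h, t, r = map(int, parts)
        triples.append((h, t, r))
    return triples


entity2id = parse_id_map(["/m/027rn 0", "/m/06cx9 1"])
relation2id = parse_id_map(["/location/country/form_of_government 0"])
train = parse_triples(["0 1 0", "2 3 1"])

print(entity2id["/m/06cx9"])  # 1
print(train[0])               # (0, 1, 0)
```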
Code logic
TrainDataLoader: samples the data by calling functions in the C++ library
# Iterator
def __iter__(self):
    if self.sampling_mode == "normal":
        return TrainDataSampler(self.nbatches, self.sampling)
    else:
        return TrainDataSampler(self.nbatches, self.cross_sampling)
# Sampling method handed to the iterator
def sampling(self):
    # Call the C++ sampling function; the batch buffers are passed as memory addresses
    self.lib.sampling(
        self.batch_h_addr,
        self.batch_t_addr,
        self.batch_r_addr,
        self.batch_y_addr,
        self.batch_size,
        self.negative_ent,
        self.negative_rel,
        0,
        self.filter,
        0,
        0
    )
    # Return the buffers that the C++ side has just filled
    return {
        "batch_h": self.batch_h,
        "batch_t": self.batch_t,
        "batch_r": self.batch_r,
        "batch_y": self.batch_y,
        "mode": "normal"
    }
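The sampler passes raw addresses so that the native code can fill the batch buffers in place, and Python then reads the same memory without copying. The sketch below imitates that idea with the standard library only: an `array.array` stands in for the loader's batch buffer, and a ctypes pointer write stands in for the C++ library. None of the names in it come from OpenKE.

```python
import array
import ctypes

# Preallocated buffer of four signed 64-bit integers (stand-in for a batch array).
batch_h = array.array("q", [0] * 4)
addr, length = batch_h.buffer_info()  # raw memory address and element count

# Stand-in for the native code: write through the address, not the Python object.
ptr = ctypes.cast(ctypes.c_void_p(addr), ctypes.POINTER(ctypes.c_int64))
for i in range(length):
    ptr[i] = i * 10

# The Python-side buffer sees the values written through the pointer.
print(batch_h.tolist())  # [0, 10, 20, 30]
```

The buffer must stay alive and must not be resized while the address is in use, which is one reason to allocate it once and reuse it for every batch.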
sampling calls the function of the same name in Base.cpp of the C++ library
// Starts the worker threads, each of which runs getBatch
extern "C"
void sampling(
    INT *batch_h, // address of the output buffer
    INT *batch_t,
    INT *batch_r,
    REAL *batch_y,
    INT batchSize,
    INT negRate = 1,
    INT negRelRate = 0,
    INT mode = 0,
    bool filter_flag = true,
    bool p = false,
    bool val_loss = false
) {
    pthread_t *pt = (pthread_t *)malloc(workThreads * sizeof(pthread_t));
    Parameter *para = (Parameter *)malloc(workThreads * sizeof(Parameter));
    for (INT threads = 0; threads < workThreads; threads++) {
        para[threads].id = threads;
        // ... set the remaining parameters (omitted)
        // ...
    }
    for (INT threads = 0; threads < workThreads; threads++)
        pthread_join(pt[threads], NULL); // wait for each worker thread (running getBatch) to finish
    free(pt);
    free(para);
}
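The start-then-join pattern used by sampling can be sketched in Python with the threading module. Everything here is illustrative: `fill_slice`, `WORK_THREADS`, and the ceiling-division split are assumptions made for the example, not a transcription of the C++ code.

```python
import threading

WORK_THREADS = 3   # hypothetical worker count
BATCH_SIZE = 8
batch = [None] * BATCH_SIZE


def fill_slice(worker_id: int) -> None:
    """Each worker fills its own contiguous slice of the shared batch."""
    chunk = -(-BATCH_SIZE // WORK_THREADS)  # ceiling division
    lo = worker_id * chunk
    hi = min(lo + chunk, BATCH_SIZE)
    for i in range(lo, hi):
        batch[i] = worker_id


# Start one thread per worker (the analogue of pthread_create) ...
threads = [
    threading.Thread(target=fill_slice, args=(w,)) for w in range(WORK_THREADS)
]
for t in threads:
    t.start()

# ... then block until every worker has finished (the analogue of pthread_join).
for t in threads:
    t.join()

print(batch)  # [0, 0, 0, 1, 1, 1, 2, 2]
```

Because the slices do not overlap, the workers never write to the same position, so no lock is needed around the shared buffer.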
The C++ getBatch function: corrupts triples to obtain negative samples
// Called by sampling; corrupts triples to build negative samples
void* getBatch(void* con) {
    Parameter *para = (Parameter *)(con);
    INT id = para -> id;
    // Read the remaining parameters (omitted)
    bool p = para -> p;
    bool val_loss = para -> val_loss;
    INT mode = para -> mode;
    bool filter_flag = para -> filter_flag;
    INT lef, rig;
    if (batchSize % workThreads == 0) {
le