#include<iostream>
#include<cstring>
#include<cstdio>
#include<map>
#include<vector>
#include<string>
#include<ctime>
#include<cmath>
#include<cstdlib>
using namespace std;
// Value of pi used by normal(). A typed constant instead of a macro, so it is
// scoped and type-checked like any other variable.
const double pi = 3.1415926535897932384626433832795;
// Distance used by the trainer: true -> L1 norm of (t - h - r),
// false -> sum of squared components (squared L2 norm).
bool L1_flag = true;
// ---- Random sampling helpers (uniform and normal distribution) ----

// Returns a uniformly distributed double in the half-open interval [min, max).
// rand() yields a value in [0, RAND_MAX], and dividing by RAND_MAX + 1 keeps
// the scaled fraction strictly below 1.
double rand(double min, double max)
{
    const double span = max - min;
    const double scaled = span * rand();
    return min + scaled / (RAND_MAX + 1.0);
}
// Probability density of the normal distribution N(miu, sigma^2) at x:
//   exp(-(x - miu)^2 / (2 * sigma^2)) / (sqrt(2 * pi) * sigma)
double normal(double x, double miu,double sigma)
{
    const double offset = x - miu;
    const double scale = 1.0 / sqrt(2 * pi) / sigma;
    const double exponent = -offset * offset / (2 * sigma * sigma);
    return scale * exp(exponent);
}
// Draws a sample from N(miu, sigma^2) restricted to [min, max) by rejection
// sampling. A candidate is proposed uniformly from [min, max). A uniform
// height is drawn from [0, pdf(miu)), where pdf(miu) is the peak of the
// density. The candidate is accepted unless that height exceeds pdf(candidate).
double randn(double miu,double sigma, double min ,double max)
{
    const double peak = normal(miu, miu, sigma);  // pure function: same value every iteration
    while (true)
    {
        const double candidate = rand(min, max);
        const double density = normal(candidate, miu, sigma);
        const double height = rand(0.0, peak);
        if (height > density)
            continue;  // rejected: propose a new candidate
        return candidate;
    }
}
// Returns x squared.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}
// Returns the Euclidean (L2) length of vector a; 0 for an empty vector.
// The vector is only read, so it is taken by const reference. This also lets
// callers pass const vectors. The size_t index avoids a signed/unsigned
// comparison against a.size().
double vec_len(const vector<double> &a)
{
    double sum_sq = 0;
    for (size_t i = 0; i < a.size(); i++)
        sum_sq += a[i] * a[i];
    return sqrt(sum_sq);
}
// ---- Global dataset state ----

// Output-file suffix ("bern" or "unif"), chosen in main().
string version;
// Scratch buffers for fscanf token reads (buf1 is not used in this file).
char buf[100000];
char buf1[100000];
// Number of relations / entities counted while loading the data.
int relation_num;
int entity_num;
// Name <-> id lookup tables.
map<string,int> relation2id;
map<string,int> entity2id;
map<int,string> id2entity;
map<int,string> id2relation;
// Per relation id: number of training triples using each head (left) /
// tail (right) entity id.
map<int,map<int,int> > left_entity;
map<int,map<int,int> > right_entity;
// Per relation id: average number of triples per distinct head (left_num) /
// per distinct tail (right_num); used for "bern" negative sampling.
map<int,double> left_num;
map<int,double> right_num;
// TransE trainer. Learns entity and relation embeddings so that
// head + relation lies close to tail for training triples. It runs mini-batch
// SGD on a margin-based ranking loss against corrupted (negative) triples.
// Embeddings are written to relation2vec.<version> and entity2vec.<version>
// after every epoch.
class Train{
public:
    // ok[(head, relation)] contains every tail t such that
    // (head, relation, t) was registered with add().
    map<pair<int,int>, map<int,int> > ok;

    // Registers one training triple: x = head id, y = tail id, z = relation id.
    void add(int x,int y,int z)
    {
        fb_h.push_back(x);
        fb_r.push_back(z);
        fb_l.push_back(y);
        ok[make_pair(x,z)][y]=1;
    }

    // Initializes the embeddings and trains them.
    //   n_in      embedding dimension
    //   rate_in   learning rate
    //   margin_in margin of the ranking loss
    //   method_in 0 -> corrupt head or tail with equal probability ("unif"),
    //             otherwise use the per-relation ratio right_num/left_num ("bern")
    // Table sizes come from the globals relation_num and entity_num.
    void run(int n_in,double rate_in,double margin_in,int method_in)
    {
        n = n_in;
        rate = rate_in;
        margin = margin_in;
        method = method_in;

        relation_vec.assign(relation_num, vector<double>(n));
        entity_vec.assign(entity_num, vector<double>(n));
        relation_tmp.assign(relation_num, vector<double>(n));
        entity_tmp.assign(entity_num, vector<double>(n));

        // Each component is drawn from N(0, (1/n)^2) truncated to
        // [-6/sqrt(n), 6/sqrt(n)). Relation vectors are not rescaled here;
        // each entity vector is clipped to length <= 1.
        const double bound = 6/sqrt(static_cast<double>(n));
        for (int i=0; i<relation_num; i++)
            for (int ii=0; ii<n; ii++)
                relation_vec[i][ii] = randn(0,1.0/n,-bound,bound);
        for (int i=0; i<entity_num; i++)
        {
            for (int ii=0; ii<n; ii++)
                entity_vec[i][ii] = randn(0,1.0/n,-bound,bound);
            norm(entity_vec[i]);
        }
        bfgs();
    }

private:
    int n,method;                   // embedding dimension; sampling method (0 = unif)
    double res;                     // summed hinge loss of the current epoch
    double rate,margin;             // learning rate; ranking-loss margin
    vector<int> fb_h,fb_l,fb_r;     // heads, tails, relations of the training triples (parallel)
    vector<vector<double> > relation_vec,entity_vec;    // embeddings at the start of the batch
    vector<vector<double> > relation_tmp,entity_tmp;    // embeddings being updated in the batch

    // Rescales a in place to unit length if its Euclidean length exceeds 1.
    // Always returns 0.
    double norm(vector<double> &a)
    {
        double x = vec_len(a);
        if (x>1)
            for (size_t ii=0; ii<a.size(); ii++)
                a[ii]/=x;
        return 0;
    }

    // Returns a pseudo-random integer in [0, x); x must be positive.
    // The two rand() draws are multiplied in long long. With a 31-bit RAND_MAX
    // (e.g. glibc) the int product overflows, which is undefined behaviour;
    // the product of two values <= 2^31-1 always fits in long long and is
    // non-negative.
    int rand_max(int x)
    {
        long long r = static_cast<long long>(rand()) * rand();
        return static_cast<int>(r % x);
    }

    // True if (h, r, t) is a registered training triple. Uses find() rather
    // than operator[] so probing random corrupted heads does not insert an
    // empty entry into `ok` for every (entity, relation) pair tried.
    bool is_known(int h,int r,int t) const
    {
        map<pair<int,int>, map<int,int> >::const_iterator it = ok.find(make_pair(h,r));
        return it != ok.end() && it->second.count(t) > 0;
    }

    // Mini-batch SGD (despite the name, not BFGS). Runs nepoch epochs of
    // nbatches batches, each drawing batchsize random training triples.
    // Gradients are computed from relation_vec/entity_vec and applied to the
    // *_tmp copies, which replace the parameters at the end of each batch.
    void bfgs()
    {
        res=0;
        int nbatches=100;
        int nepoch = 1000;
        int batchsize = fb_h.size()/nbatches;
        for (int epoch=0; epoch<nepoch; epoch++)
        {
            res=0;
            for (int batch = 0; batch<nbatches; batch++)
            {
                relation_tmp=relation_vec;
                entity_tmp = entity_vec;
                for (int k=0; k<batchsize; k++)
                {
                    int i=rand_max(fb_h.size());    // random training triple
                    int j=rand_max(entity_num);     // candidate replacement entity
                    // pr/1000 = right_num/(right_num+left_num) is the chance of
                    // replacing the tail ("bern"); method 0 uses 1/2 ("unif").
                    double pr = 1000*right_num[fb_r[i]]/(right_num[fb_r[i]]+left_num[fb_r[i]]);
                    if (method ==0)
                        pr = 500;
                    if (rand()%1000<pr)
                    {
                        // Corrupt the tail; resample until (h, r, j) is not a training triple.
                        while (is_known(fb_h[i],fb_r[i],j))
                            j=rand_max(entity_num);
                        train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]);
                    }
                    else
                    {
                        // Corrupt the head; resample until (j, r, t) is not a training triple.
                        while (is_known(j,fb_r[i],fb_l[i]))
                            j=rand_max(entity_num);
                        train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]);
                    }
                    // Keep every vector touched by this step within the unit ball.
                    norm(relation_tmp[fb_r[i]]);
                    norm(entity_tmp[fb_h[i]]);
                    norm(entity_tmp[fb_l[i]]);
                    norm(entity_tmp[j]);
                }
                relation_vec = relation_tmp;
                entity_vec = entity_tmp;
            }
            cout<<"epoch:"<<epoch<<' '<<res<<endl;
            write_rows("relation2vec."+version, relation_vec, relation_num);
            write_rows("entity2vec."+version, entity_vec, entity_num);
        }
    }

    // Writes rows[0..count) to `path` (replacing its contents): one row per
    // line, n values each formatted "%.6lf" and followed by a tab. If the file
    // cannot be opened, reports it on stderr and writes nothing.
    void write_rows(const string& path, const vector<vector<double> >& rows, int count)
    {
        FILE* f = fopen(path.c_str(),"w");
        if (f == NULL)
        {
            fprintf(stderr,"cannot open %s for writing\n",path.c_str());
            return;
        }
        for (int i=0; i<count; i++)
        {
            for (int ii=0; ii<n; ii++)
                fprintf(f,"%.6lf\t",rows[i][ii]);
            fprintf(f,"\n");
        }
        fclose(f);
    }

    // Distance between e1 + rel and e2 under the current embeddings: the L1
    // norm when L1_flag is set, otherwise the sum of squared components.
    double calc_sum(int e1,int e2,int rel)
    {
        double sum=0;
        for (int ii=0; ii<n; ii++)
        {
            double d = entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii];
            sum += L1_flag ? fabs(d) : sqr(d);
        }
        return sum;
    }

    // One SGD step for a (positive, negative) pair. It pulls e1_a + rel_a
    // toward e2_a and pushes e1_b + rel_b away from e2_b. Gradients are read
    // from relation_vec/entity_vec and applied to relation_tmp/entity_tmp.
    // With L1_flag each gradient component is replaced by its sign (+1 / -1).
    void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
    {
        for (int ii=0; ii<n; ii++)
        {
            double x = 2*(entity_vec[e2_a][ii]-entity_vec[e1_a][ii]-relation_vec[rel_a][ii]);
            if (L1_flag)
                x = (x > 0) ? 1 : -1;
            // Positive triple: move e1_a + rel_a toward e2_a.
            relation_tmp[rel_a][ii]+=rate*x;
            entity_tmp[e1_a][ii]+=rate*x;
            entity_tmp[e2_a][ii]-=rate*x;

            x = 2*(entity_vec[e2_b][ii]-entity_vec[e1_b][ii]-relation_vec[rel_b][ii]);
            if (L1_flag)
                x = (x > 0) ? 1 : -1;
            // Negative triple: move e1_b + rel_b away from e2_b.
            relation_tmp[rel_b][ii]-=rate*x;
            entity_tmp[e1_b][ii]-=rate*x;
            entity_tmp[e2_b][ii]+=rate*x;
        }
    }

    // Hinge loss max(0, margin + d(positive) - d(negative)). When it is
    // positive, it is added to res and a gradient step is taken.
    void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
    {
        double sum1 = calc_sum(e1_a,e2_a,rel_a);
        double sum2 = calc_sum(e1_b,e2_b,rel_b);
        if (sum1+margin>sum2)
        {
            res+=margin+sum1-sum2;
            gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
        }
    }
};
// Global trainer instance. prepare() registers training triples with it via
// train.add(), and main() then starts training with train.run().
Train train;
// Opens `path` for reading; prints an error and exits with status 1 on failure.
static FILE* open_input_or_exit(const char* path)
{
    FILE* f = fopen(path, "r");
    if (f == NULL)
    {
        fprintf(stderr, "cannot open %s\n", path);
        exit(1);
    }
    return f;
}

// Mean of the counts stored in m, i.e. training triples per distinct entity.
// Returns 0 when m is empty (a relation that never occurs in train.txt).
static double average_count(const map<int,int>& m)
{
    if (m.empty())
        return 0;
    double total = 0;
    for (map<int,int>::const_iterator it = m.begin(); it != m.end(); ++it)
        total += it->second;
    return total / m.size();
}

// Loads the data:
//   ../data/entity2id.txt and ../data/relation2id.txt: whitespace-separated
//     "name id" pairs;
//   ../data/train.txt: "head tail relation" name triples.
// Fills the global lookup tables, registers every usable triple with `train`,
// and computes left_num / right_num for "bern" negative sampling.
// Exits with status 1 if an input file cannot be opened.
void prepare()
{
    int x;
    // "%99999s" keeps every token inside the 100000-byte global buffer buf.
    FILE* f1 = open_input_or_exit("../data/entity2id.txt");
    while (fscanf(f1,"%99999s%d",buf,&x)==2)
    {
        string st=buf;
        entity2id[st]=x;
        id2entity[x]=st;
        entity_num++;
    }
    fclose(f1);

    FILE* f2 = open_input_or_exit("../data/relation2id.txt");
    while (fscanf(f2,"%99999s%d",buf,&x)==2)
    {
        string st=buf;
        relation2id[st]=x;
        id2relation[x]=st;
        relation_num++;
    }
    fclose(f2);

    FILE* f_kb = open_input_or_exit("../data/train.txt");
    while (fscanf(f_kb,"%99999s",buf)==1)
    {
        string s1=buf;
        if (fscanf(f_kb,"%99999s",buf)!=1)
            break;  // truncated final triple: stop instead of reusing a stale token
        string s2=buf;
        if (fscanf(f_kb,"%99999s",buf)!=1)
            break;
        string s3=buf;
        bool missing = false;
        if (entity2id.count(s1)==0)
        {
            cout<<"miss entity:"<<s1<<endl;
            missing = true;
        }
        if (entity2id.count(s2)==0)
        {
            cout<<"miss entity:"<<s2<<endl;
            missing = true;
        }
        if (relation2id.count(s3)==0)
        {
            // Relations missing from relation2id.txt get the next free id.
            relation2id[s3] = relation_num;
            id2relation[relation_num] = s3;
            relation_num++;
        }
        // An unknown entity has no valid id; training the triple would
        // silently treat it as entity 0, so skip it.
        if (missing)
            continue;
        int h = entity2id[s1];
        int t = entity2id[s2];
        int r = relation2id[s3];
        left_entity[r][h]++;   // triples of relation r with head h
        right_entity[r][t]++;  // triples of relation r with tail t
        train.add(h,t,r);
    }
    fclose(f_kb);

    // left_num[r]: average triples per distinct head (tails per head);
    // right_num[r]: average triples per distinct tail (heads per tail).
    for (int i=0; i<relation_num; i++)
    {
        left_num[i] = average_count(left_entity[i]);
        right_num[i] = average_count(right_entity[i]);
    }
    cout<<"relation_num="<<relation_num<<endl;
    cout<<"entity_num="<<entity_num<<endl;
}
// Looks for the command-line flag `str` in argv[1..argc-1]. argv[0] is never
// matched. Returns its index, or -1 when the flag is absent. A flag given as
// the very last argument has no value after it: prints an error and exits
// with status 1.
int ArgPos(char *str, int argc, char **argv) {
    for (int idx = 1; idx < argc; ++idx) {
        if (strcmp(str, argv[idx]) != 0)
            continue;
        if (idx == argc - 1) {
            printf("Argument missing for %s\n", str);
            exit(1);
        }
        return idx;
    }
    return -1;
}
// Entry point. Options (each takes a value):
//   -size N     embedding dimension (default 100)
//   -rate R     learning rate (default 0.001)
//   -margin M   ranking-loss margin (default 1; fractional values allowed)
//   -method K   0 -> "unif", non-zero -> "bern" negative sampling (default 1)
// Loads the data with prepare() and trains with train.run().
int main(int argc,char**argv)
{
    srand((unsigned) time(NULL));
    int method = 1;
    int n = 100;
    double rate = 0.001;
    double margin = 1;
    int i;
    if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-rate", argc, argv)) > 0) rate = atof(argv[i + 1]);
    // margin is a double: atof keeps values such as 0.5, which atoi truncated to 0.
    if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
    cout<<"size = "<<n<<endl;
    cout<<"learing rate = "<<rate<<endl;
    cout<<"margin = "<<margin<<endl;
    if (method)
        version = "bern";
    else
        version = "unif";
    cout<<"method = "<<version<<endl;
    prepare();
    train.run(n,rate,margin,method);
    return 0;
}
// Train_TransE代码注释 (Train_TransE: annotated code)
// 最新推荐文章于 2023-05-10 17:31:46 发布 (latest recommended article published 2023-05-10 17:31:46)