- 参数args.n_batches 设置训练多少次epoch,然后开始训练
- train_order_embedding()函数是模型训练的函数
- 首先训练每个节点的图嵌入向量
对每一批数据,经过图嵌入模型得到向量,对模型预测结果进行损失计算
论文中有损失计算公式
损失反向传播,并优化参数
- 对于clf_model模型
计算emb_a和emb_b之间的匹配分数(论文中目标函数E),
将该E值输入到clf_model线形层中,用交叉熵损失训练该clf_model线性模型。
最后argmax输出二分类结果(匹配 or 不匹配),根据结果得到预测值pred和准确率
其中预测值也是论文中的目标函数E的公式计算
- 最后yield返回训练参数即可