SimplE:SimplE Embedding for Link Prediction in Knowledge Graphs+代码


本文主要对知识图谱补全论文进行讲解,对其原理、思路、代码等方面进行详细讲解。

1 介绍

1.1 知识图谱

现实世界中可以通过三元组,对现实世界进行描述,三元组描述的形式为:实体-关系-实体或实体-关系-属性,通常可以用符号表示(h, r, t)其中h表示头实体,r表示实体和实体之间的关系,t表示尾实体。以该方式存储现实世界的各种关系以及属性,由于现实世界是非常的庞大的,各种关系容易缺失,因此涉及知识图谱补全,本文SimplE就是一篇针对知识图谱补全的论文。

1.2 知识图谱补全方法

基础的知识图谱补全涉及以下几种方法

  • 双线性模型(disMult)
  • 神经网络模型 (ConvKB)
  • 转换模型(Trans系列)
  • 图卷积模型(GCN)

1.3知识图谱补全(Knowledge Graph Completion,KGC)

目前主要被抽象成一个预测问题,即预测出三元组中缺失的部分。所以可分成3个子任务:

  • 头实体预测:(?, r, t)
  • 关系预测:(h, ?, t)
  • 尾实体预测:(h, r, ?)

1.4 关系分类

E \mathcal{E} E and R \mathcal{R} R 分别表示实体和关系,一个三元组用 ( h , r , t ) (h, r, t) (h,r,t)表示, 其中 h ∈ E h \in \mathcal{E} hE 是头实体, r ∈ R r \in \mathcal{R} rR 是关系, and t ∈ E t \in \mathcal{E} tE 是尾实体。 ζ \zeta ζ 表示 一个真实的三元组(e.g., (paris, capitalOf, france)), 而 ζ ′ \zeta^{\prime} ζ 表示错误的三元组 (e.g., ( paris, capitalOf, italy)). 知识图谱 K G \mathcal{K} \mathcal{G} KG ζ 的 子 集 . \zeta的子集 . ζ. A relation

  • 对称(symmetric): ( e 1 , r , e 2 ) ∈ ζ ⟺ ( e 2 , r , e 1 ) ∈ ζ \left(e_{1}, r, e_{2}\right) \in \zeta \Longleftrightarrow\left(e_{2}, r, e_{1}\right) \in \zeta (e1,r,e2)ζ(e2,r,e1)ζ,其中 e 1 , e 2 ∈ E e_{1}, e_{2} \in \mathcal{E} e1,e2E
  • 反身 ( reflexive): ( e , r , e ) ∈ ζ (e, r, e) \in \zeta (e,r,e)ζ for all entities e ∈ E e \in \mathcal{E} eE
  • 反对称(anti-symmetric): ( e 1 , r , e 2 ) ∈ ζ ⟺ ( e 2 , r , e 1 ) ∈ ζ ′ \left(e_{1}, r, e_{2}\right) \in \zeta \Longleftrightarrow\left(e_{2}, r, e_{1}\right) \in \zeta^{\prime} (e1,r,e2)ζ(e2,r,e1)ζ
  • 传递(transitive): ( e 1 , r , e 2 ) ∈ ζ ∧ ( e 2 , r , e 3 ) ∈ ζ ⇒ ( e 1 , r , e 3 ) ∈ ζ \left(e_{1}, r, e_{2}\right) \in \zeta \wedge\left(e_{2}, r, e_{3}\right) \in \zeta \Rightarrow\left(e_{1}, r, e_{3}\right) \in \zeta (e1,r,e2)ζ(e2,r,e3)ζ(e1,r,e3)ζ ,其中 e 1 , e 2 , e 3 ∈ E e_{1}, e_{2}, e_{3} \in \mathcal{E} e1,e2,e3E
  • 逆关系(inverse): ( e i , r , e j ) ∈ ζ ⟺ ( e j , r − 1 , e i ) ∈ ζ \left(e_{i}, r, e_{j}\right) \in \zeta \Longleftrightarrow\left(e_{j}, r^{-1}, e_{i}\right) \in \zeta (ei,r,ej)ζ(ej,r1,ei)ζ

2 模型

2.1 双线性模型

所谓的双线性模型即实体关系之间采用乘的方式,其乘的方式为 ⟨ v , w , x ⟩ ≐ ∑ j = 1 d v [ j ] ∗ w [ j ] ∗ x [ j ] \langle v, w, x\rangle \doteq \sum_{j=1}^{d} v[j] * w[j] * x[j] v,w,xj=1dv[j]w[j]x[j],Hadamard乘积,每个元素元素之间进行相乘,然后累加,其中论文disMult也采用同样的方式。

2.2 核心公式

两个向量 h e , t e ∈ R d h_{e}, t_{e} \in \mathbb{R}^{d} he,teRd作为实体 e e e的嵌入,向量 v r , v r − 1 ∈ R a v_{r}, v_{r}^{-1} \in \mathbb{R}^{a} vr,vr1Ra 作为关系 r r r的嵌入。
1 2 ( ⟨ h e i , v r , t e j ⟩ + ⟨ h e j , v r − 1 , t e i ⟩ ) \frac{1}{2}\left(\left\langle h_{e_{i}}, v_{r}, t_{e_{j}}\right\rangle+\left\langle h_{e_{j}}, v_{r^{-1}}, t_{e_{i}}\right\rangle\right) 21(hei,vr,tej+hej,vr1,tei)作为模型的核心,计算得分函数。

2.3 负采样

模型采用随机的方法,对头实体或者尾实体进行负采样,随机从[0,num_ent-1]中抽取一个不同于原始的数据, num_ent表示实体总数。正确的三元组标记label为1,错位的三元组即复杂采样的结果标记label为-1。

2.4 损失函数

min ⁡ θ ∑ ( ( h , r , t ) , l ) ∈ L B softplus ⁡ ( − l ⋅ ϕ ( h , r , t ) ) + λ ∥ θ ∥ 2 2 \min _{\theta} \sum_{((h, r, t), l) \in \mathbf{L B}} \operatorname{softplus}(-l \cdot \phi(h, r, t))+\lambda\|\theta\|_{2}^{2} minθ((h,r,t),l)LBsoftplus(lϕ(h,r,t))+λθ22,其中 θ \theta θ 代表模型参数(embeddings 的参数), l l l表示标签范围为-1或者+1即正确三元组或错误三元组。 ϕ ( h , r , t ) \phi(h, r, t) ϕ(h,r,t)表示三元组 ( h , r , t ) (h, r, t) (h,r,t)的得分函数, softplus ⁡ ( x ) = log ⁡ ( 1 + exp ⁡ ( x ) ) \operatorname{softplus}(x)=\log (1+\exp (x)) softplus(x)=log(1+exp(x))

2.5 评价

知识图谱补全评价指标有hit@n, mrr, mr等方法,博客参考KGE性能指标:MRR,MR,HITS@1,HITS@3,HITS@10

  • MRR

    MRR的全称是Mean Reciprocal Ranking,其中Reciprocal是指“倒数的”的意思。具体的计算方法如下:
    M R R = 1 ∣ S ∣ ∑ i = 1 1 rank ⁡ i = 1 ∣ S ∣ ( 1 rank ⁡ 1 + 1 rank ⁡ 2 + … + 1 rank ⁡ ∣ S ∣ ) \mathrm{MRR}=\frac{1}{|S|} \sum_{i=1} \frac{1}{\operatorname{rank}_{i}}=\frac{1}{|S|}\left(\frac{1}{\operatorname{rank}_{1}}+\frac{1}{\operatorname{rank}_{2}}+\ldots+\frac{1}{\operatorname{rank}_{|S|}}\right) MRR=S1i=1ranki1=S1(rank11+rank21++rankS1),其中S是三元组集合,|S|是三元组集合个数, r a n k i rank_{i} ranki是指第 i i i个三元组的链接预测排名。该指标越大越好。

  • MR

    MR的全称是Mean Rank。具体的计算方法如下:
    M R = 1 ∣ S ∣ ∑ i = 1 ∣ S ∣ rank ⁡ i = 1 ∣ S ∣ ( rank ⁡ 1 + rank ⁡ 2 + … + rank ⁡ ∣ S ∣ ) \mathbf{M R}=\frac{1}{|S|} \sum_{i=1}^{|S|} \operatorname{rank}_{i}=\frac{1}{|S|}\left(\operatorname{rank}_{1}+\operatorname{rank}_{2}+\ldots+\operatorname{rank}_{|S|}\right) MR=S1i=1Sranki=S1(rank1+rank2++rankS)
    上述公式涉及的符号和MRR计算公式中涉及的符号一样。该指标越小越好。

  • HITS@n

    该指标是指在链接预测中排名小于n的三元组的平均占比。具体的计算方法如下:
    HITS ⁡ @ n = 1 ∣ S ∣ ∑ i = 1 ∣ S ∣ I ( rank ⁡ i ⩽ n ) \operatorname{HITS} @ n=\frac{1}{|S|} \sum_{i=1}^{|S|} \mathbb{I}\left(\operatorname{rank}_{i} \leqslant n\right) HITS@n=S1i=1SI(rankin)

    其中,上述公式涉及的符号和MRR计算公式中涉及的符号一样,另外 I ( ⋅ ) \mathbb{I}(\cdot) I() 是indicator函数(若条件真则函数值为1,否则为0)。一般地,取n等于1、3或者10。该指标越大越好。

3 代码

代码包括6个模块,分别为:数据处理,模型模块,训练模块,测试模块,评价模块,主模块等,如图所示:
Refused

3.1 数据处理模块 dataset.py

import numpy as np
import random
import torch
import math

class Dataset:
    def __init__(self, ds_name):
        self.name = ds_name
        self.dir = "datasets/" + ds_name + "/"
        self.ent2id = {}
        self.rel2id = {}
        self.data = {spl: self.read(self.dir + spl + ".txt") for spl in ["train", "valid", "test"]}
        self.batch_index = 0
       
    def read(self, file_path): #读取文件,train,text, valid数据集
        with open(file_path, "r") as f:
            lines = f.readlines()
        
        triples = np.zeros((len(lines), 3)) #初始化数据为(0 0 0)形式
		#[(0, 0, 0),
		#(0, 0, 0),
		#(0, 0, 0)]
        for i, line in enumerate(lines):#填充数据
            triples[i] = np.array(self.triple2ids(line.strip().split("\t")))
        return triples
    
    def num_ent(self): #总实体个数
        return len(self.ent2id)
    
    def num_rel(self):#总关系个数
        return len(self.rel2id)
                     
    def triple2ids(self, triple):#将实体转化成id的形式
        return [self.get_ent_id(triple[0]), self.get_rel_id(triple[1]), self.get_ent_id(triple[2])]
                     
    def get_ent_id(self, ent):#实体转成id的形式
        if not ent in self.ent2id:
            self.ent2id[ent] = len(self.ent2id)
        return self.ent2id[ent]
            
    def get_rel_id(self, rel):#关系转成id的形式
        if not rel in self.rel2id:
            self.rel2id[rel] = len(self.rel2id)
        return self.rel2id[rel]
                     
    def rand_ent_except(self, ent):  #进行随机生成一个[0, num_ent-1]之间,但不能随机生成真实的值,如果生成则扔掉
        rand_ent = random.randint(0, self.num_ent() - 1)
        while(rand_ent == ent):#while循环的作用就是提出真实值
            rand_ent = random.randint(0, self.num_ent() - 1)
        return rand_ent
                     
    def next_pos_batch(self, batch_size):#进行batch操作,每次取一个batch_size,
        if self.batch_index + batch_size < len(self.data["train"]):#数据足够当前batch_size
            batch = self.data["train"][self.batch_index: self.batch_index+batch_size]
            self.batch_index += batch_size
        else:#数据不够当前的batch_size
            batch = self.data["train"][self.batch_index:]
            self.batch_index = 0
        return np.append(batch, np.ones((len(batch), 1)), axis=1).astype("int") #appending the +1 label
		#增加一列,表示lable,因为都是都是正确的三元组,所以lable都是1
		#[(1,1, 3, 1),
		#(1,1, 3, 1),
		#(1,1, 3, 1)]
                     
    def generate_neg(self, pos_batch, neg_ratio):#进行负采样,neg_ratio表示负载样的个数
        neg_batch = np.repeat(np.copy(pos_batch), neg_ratio, axis=0)#对正确的三元组进行复制,复制neg_ratio个
        for i in range(len(neg_batch)):
            if random.random() < 0.5:#表示随机对头实体或者尾实体进行负采样
                neg_batch[i][0] = self.rand_ent_except(neg_batch[i][0]) #flipping head
            else:
                neg_batch[i][2] = self.rand_ent_except(neg_batch[i][2]) #flipping tail
		#负采样所有的label都是-1
        neg_batch[:,-1] = -1
        return neg_batch

    def next_batch(self, batch_size, neg_ratio, device):#取一个batch
        pos_batch = self.next_pos_batch(batch_size)#生成一个争取的pos_batch
        neg_batch = self.generate_neg(pos_batch, neg_ratio)#生成错误的neg_batch
        batch = np.append(pos_batch, neg_batch, axis=0)#将正负batch放一起
        np.random.shuffle(batch)#进行shuffle,混乱
		#提取head, rel, tail, lable
        heads  = torch.tensor(batch[:,0]).long().to(device)
        rels   = torch.tensor(batch[:,1]).long().to(device)
        tails  = torch.tensor(batch[:,2]).long().to(device)
        labels = torch.tensor(batch[:,3]).float().to(device)
        return heads, rels, tails, labels
    
    def was_last_batch(self):
        return (self.batch_index == 0)

    def num_batch(self, batch_size):
        return int(math.ceil(float(len(self.data["train"])) / batch_size))


3.2 模型模块 model.py

import torch
import torch.nn as nn
import math

class SimplE(nn.Module):
    def __init__(self, num_ent, num_rel, emb_dim, device):
        super(SimplE, self).__init__()
        self.num_ent = num_ent
        self.num_rel = num_rel
        self.emb_dim = emb_dim
        self.device = device
		#进行配置两个实体的embedding和两个关系的embedding
		#头实体的embedding
        self.ent_h_embs   = nn.Embedding(num_ent, emb_dim).to(device)
		#尾实体的embedding 
        self.ent_t_embs   = nn.Embedding(num_ent, emb_dim).to(device)
		#正关系的embedding
        self.rel_embs     = nn.Embedding(num_rel, emb_dim).to(device)
		#逆关系的embedding
        self.rel_inv_embs = nn.Embedding(num_rel, emb_dim).to(device)

        sqrt_size = 6.0 / math.sqrt(self.emb_dim)
		#embedding 参数进行初始化
        nn.init.uniform_(self.ent_h_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.ent_t_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_embs.weight.data, -sqrt_size, sqrt_size)
        nn.init.uniform_(self.rel_inv_embs.weight.data, -sqrt_size, sqrt_size)
        
	#embedding 参数l2范式
    def l2_loss(self):
        return ((torch.norm(self.ent_h_embs.weight, p=2) ** 2) + (torch.norm(self.ent_t_embs.weight, p=2) ** 2) + (torch.norm(self.rel_embs.weight, p=2) ** 2) + (torch.norm(self.rel_inv_embs.weight, p=2) ** 2)) / 2

    def forward(self, heads, rels, tails):
        hh_embs = self.ent_h_embs(heads)
        ht_embs = self.ent_h_embs(tails)
        th_embs = self.ent_t_embs(heads)
        tt_embs = self.ent_t_embs(tails)
        r_embs = self.rel_embs(rels)
        r_inv_embs = self.rel_inv_embs(rels)

        scores1 = torch.sum(hh_embs * r_embs * tt_embs, dim=1)
        scores2 = torch.sum(ht_embs * r_inv_embs * th_embs, dim=1)
		#核心score函数
        return torch.clamp((scores1 + scores2) / 2, -20, 20)
        

3.3 训练模块 Trainer.py

from dataset import Dataset
from SimplE import SimplE
import torch
import torch.nn as nn
import torch.nn.functional as F
import os 

class Trainer:
    def __init__(self, dataset, args):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
		#设置模型
        self.model = SimplE(dataset.num_ent(), dataset.num_rel(), args.emb_dim, self.device)
        #数据集处理类
		self.dataset = dataset
        self.args = args
        
    def train(self):
        self.model.train()

        optimizer = torch.optim.Adagrad(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay= 0,
            initial_accumulator_value= 0.1 #this is added because of the consistency to the original tensorflow code
        )
		#进行训练
        for epoch in range(1, self.args.ne + 1):
            last_batch = False
            total_loss = 0.0

            while not last_batch:
                #获取一个batch的h, r, t, l,注意是l不是数字1
				h, r, t, l = self.dataset.next_batch(self.args.batch_size, neg_ratio=self.args.neg_ratio, device = self.device)
                last_batch = self.dataset.was_last_batch()
                optimizer.zero_grad()
				#打分函数
                scores = self.model(h, r, t)
				#损失函数
                loss = torch.sum(F.softplus(-l * scores))+ (self.args.reg_lambda * self.model.l2_loss() / self.dataset.num_batch(self.args.batch_size))
                loss.backward()
                optimizer.step()
                total_loss += loss.cpu().item()

            print("Loss in iteration " + str(epoch) + ": " + str(total_loss) + "(" + self.dataset.name + ")")
        
            if epoch % self.args.save_each == 0:
                self.save_model(epoch)

    def save_model(self, chkpnt):#保存模型
        print("Saving the model")
        directory = "models/" + self.dataset.name + "/"
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(self.model, directory + str(chkpnt) + ".chkpnt")

3.4 测试模块 Test.py

import torch
from dataset import Dataset
import numpy as np
from measure import Measure
from os import listdir
from os.path import isfile, join

class Tester:
    def __init__(self, dataset, model_path, valid_or_test):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = torch.load(model_path, map_location = self.device)
        self.model.eval()
        self.dataset = dataset
        self.valid_or_test = valid_or_test
        self.measure = Measure()
        self.all_facts_as_set_of_tuples = set(self.allFactsAsTuples())

	#计算rank,通过和正确数据进行比较
    def get_rank(self, sim_scores):#assuming the test fact is the first one
        return (sim_scores >= sim_scores[0]).sum()
	
	#进行采用,生成所有实体对应的数据
	#[(1, 2, 3),
	#(2, 2, 3),
	#(3, 2, 3),
	#.....
	#(n, 2, 3)]
    def create_queries(self, fact, head_or_tail):
        head, rel, tail = fact
        if head_or_tail == "head":
            return [(i, rel, tail) for i in range(self.dataset.num_ent())]
        elif head_or_tail == "tail":
            return [(head, rel, i) for i in range(self.dataset.num_ent())]
	
	#raw 所有数据
	#fil是提出训练集出现的数据
    def add_fact_and_shred(self, fact, queries, raw_or_fil):
        if raw_or_fil == "raw":
            result = [tuple(fact)] + queries
        elif raw_or_fil == "fil":
            result = [tuple(fact)] + list(set(queries) - self.all_facts_as_set_of_tuples)

        return self.shred_facts(result)

    # def replace_and_shred(self, fact, raw_or_fil, head_or_tail):
    #     ret_facts = []
    #     head, rel, tail = fact
    #     for i in range(self.dataset.num_ent()):
    #         if head_or_tail == "head" and i != head:
    #             ret_facts.append((i, rel, tail))
    #         if head_or_tail == "tail" and i != tail:
    #             ret_facts.append((head, rel, i))

    #     if raw_or_fil == "raw":
    #         ret_facts = [tuple(fact)] + ret_facts
    #     elif raw_or_fil == "fil":
    #         ret_facts = [tuple(fact)] + list(set(ret_facts) - self.all_facts_as_set_of_tuples)

    #     return self.shred_facts(ret_facts)
    
    def test(self):
        settings = ["raw", "fil"] if self.valid_or_test == "test" else ["fil"]
        
        for i, fact in enumerate(self.dataset.data[self.valid_or_test]):
            for head_or_tail in ["head", "tail"]:
                queries = self.create_queries(fact, head_or_tail)
                for raw_or_fil in settings:
                    h, r, t = self.add_fact_and_shred(fact, queries, raw_or_fil)
                    sim_scores = self.model(h, r, t).cpu().data.numpy()
                    rank = self.get_rank(sim_scores)
                    self.measure.update(rank, raw_or_fil)

        self.measure.normalize(len(self.dataset.data[self.valid_or_test]))
        self.measure.print_()
        return self.measure.mrr["fil"]


	#获取h, r, t
    def shred_facts(self, triples):
        heads  = [triples[i][0] for i in range(len(triples))]
        rels   = [triples[i][1] for i in range(len(triples))]
        tails  = [triples[i][2] for i in range(len(triples))]
        return torch.LongTensor(heads).to(self.device), torch.LongTensor(rels).to(self.device), torch.LongTensor(tails).to(self.device)
	#所有的正确的三元组, train, test, valid
    def allFactsAsTuples(self):
        tuples = []
        for spl in self.dataset.data:
            for fact in self.dataset.data[spl]:
                tuples.append(tuple(fact))
        
        return tuples	

3.5 评价模块 Measure.py

class Measure:
    def __init__(self):
        self.hit1  = {"raw": 0.0, "fil": 0.0}
        self.hit3  = {"raw": 0.0, "fil": 0.0}
        self.hit10 = {"raw": 0.0, "fil": 0.0}
        self.mrr   = {"raw": 0.0, "fil": 0.0}
        self.mr    = {"raw": 0.0, "fil": 0.0}

    def update(self, rank, raw_or_fil):
        if rank == 1:
            self.hit1[raw_or_fil] += 1.0
        if rank <= 3:
            self.hit3[raw_or_fil] += 1.0
        if rank <= 10:
            self.hit10[raw_or_fil] += 1.0

        self.mr[raw_or_fil]  += rank
        self.mrr[raw_or_fil] += (1.0 / rank)
    
    def normalize(self, num_facts):
        for raw_or_fil in ["raw", "fil"]:
            self.hit1[raw_or_fil]  /= (2 * num_facts)
            self.hit3[raw_or_fil]  /= (2 * num_facts)
            self.hit10[raw_or_fil] /= (2 * num_facts)
            self.mr[raw_or_fil]    /= (2 * num_facts)
            self.mrr[raw_or_fil]   /= (2 * num_facts)

    def print_(self):
        for raw_or_fil in ["raw", "fil"]:
            print(raw_or_fil.title() + " setting:")
            print("\tHit@1 =",  self.hit1[raw_or_fil])
            print("\tHit@3 =",  self.hit3[raw_or_fil])
            print("\tHit@10 =", self.hit10[raw_or_fil])
            print("\tMR =",     self.mr[raw_or_fil])
            print("\tMRR =",    self.mrr[raw_or_fil])
            print("")

3.6 主模块 Main.py

from trainer import Trainer
from tester import Tester
from dataset import Dataset
import argparse
import time
def get_parameter():
    parser = argparse.ArgumentParser()
    parser.add_argument('-ne', default=1000, type=int, help="number of epochs")
    parser.add_argument('-lr', default=0.1, type=float, help="learning rate")
    parser.add_argument('-reg_lambda', default=0.03, type=float, help="l2 regularization parameter")
    parser.add_argument('-dataset', default="WN18", type=str, help="wordnet dataset")
    parser.add_argument('-emb_dim', default=200, type=int, help="embedding dimension")
    parser.add_argument('-neg_ratio', default=1, type=int, help="number of negative examples per positive example")
    parser.add_argument('-batch_size', default=1415, type=int, help="batch size")
    parser.add_argument('-save_each', default=50, type=int, help="validate every k epochs")
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_parameter()
    dataset = Dataset(args.dataset)

    print("~~~~ Training ~~~~")
    trainer = Trainer(dataset, args)
    trainer.train()

    print("~~~~ Select best epoch on validation set ~~~~")
    epochs2test = [str(int(args.save_each * (i + 1))) for i in range(args.ne // args.save_each)]
    dataset = Dataset(args.dataset)
    
    best_mrr = -1.0
    best_epoch = "0"
    for epoch in epochs2test:
        start = time.time()
        print(epoch)
        model_path = "models/" + args.dataset + "/" + epoch + ".chkpnt"
        tester = Tester(dataset, model_path, "valid")
        mrr = tester.test()
        if mrr > best_mrr:
            best_mrr = mrr
            best_epoch = epoch
        print(time.time() - start)

    print("Best epoch: " + best_epoch)

    print("~~~~ Testing on the best epoch ~~~~")
    best_model_path = "models/" + args.dataset + "/" + best_epoch + ".chkpnt"
    tester = Tester(dataset, best_model_path, "test")
    tester.test()

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值