知识图到文本的生成——拾贰

2021SC@SDUSC

目录

model类的剩余部分


前文已介绍 newmodel.py 中 model 类的一部分,本文继续分析该类的剩余部分,即用于束搜索(beam search)解码的 beam_generate 方法。

model类的剩余部分

  

def beam_generate(self,b,beamsz,k):
  """Decode batch ``b`` with beam search and return the resulting ``Beam``.

  Args:
    b: batch object; this method reads ``b.src`` (only when
      ``self.args.title`` is set), ``b.ent``, ``b.rel`` (only when
      ``self.graph`` is set) and ``b.nerd``.
    beamsz: beam width; forwarded to ``Beam`` and used to replicate the
      initial LSTM state and attention context.
    k: number of top-scoring candidates taken from the output
      distribution at every decoding step.

  Returns:
    The ``Beam`` object after at most ``self.maxlen`` decoding steps, or
    ``None`` when ``self.maxlen`` is 0 (the loop body never runs).

  Note:
    When ``self.args.plan`` is set the method prints debugging
    information to stdout.
  """
  if self.args.title:
    # Encode the title and build its attention mask with maskFromList.
    tencs,_ = self.tenc(b.src)
    tmask = self.maskFromList(tencs.size(),b.src[1]).unsqueeze(1)
  ents = b.ent
  entlens = ents[2]
  ents = self.le(ents)
  if self.graph:
    # The graph encoder's global output seeds the decoder hidden state;
    # its relation output supplies the attention keys and their mask.
    _,glob,grels = self.ge(b.rel[0],b.rel[1],(ents,entlens))
    hx = glob
    keys,mask = grels
    mask = mask==0
  else:
    # No graph encoder: attend over the entity encodings directly and
    # seed the hidden state with their element-wise max over dim 1.
    mask = self.maskFromList(ents.size(),entlens)
    hx = ents.max(dim=1)[0]
    keys = ents
  # Add a singleton dimension at position 1 for the attention call.
  mask = mask.unsqueeze(1)
  if self.args.plan:
    planlogits = self.splan.plan_decode(hx,keys,mask.clone(),entlens)
    print(planlogits.size())
    # Only the first sample's arg-max plan is used. It is rendered as a
    # space-separated string and split on "-1" into per-segment index lists.
    sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
    print(sorder)
    sorder = [x.strip() for x in sorder.split("-1")]
    sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
    # Restart from an all-zero mask and enable only the first planned
    # index of each segment.
    mask.fill_(0)
    # Per-row cursor into the plan, starting at position 0.
    planplace = torch.zeros(hx.size(0)).long()
    for i,m in enumerate(sorder):
      mask[i][0][m[0]]=1
  else:
    planlogits = None
  # Detached copy of hx as the initial cell state; this is the documented
  # equivalent of the deprecated copy-construct torch.tensor(hx).
  cx = hx.detach().clone()
  a = self.attn(hx.unsqueeze(1),keys,mask=mask).squeeze(1)
  if self.args.title:
    a2 = self.attn2(hx.unsqueeze(1),tencs,mask=tmask).squeeze(1)
    # Concatenate graph/entity context and title context along dim 1.
    a = torch.cat((a,a2),1)
  # One start token per row (int64), allocated on the same device as the
  # entity encodings instead of unconditionally on the default GPU.
  outp = torch.empty((ents.size(0),1),dtype=torch.long,device=ents.device).fill_(self.starttok)
  beam = None
  for _ in range(self.maxlen):
    op = self.emb_w_vertex(outp.clone(),b.nerd)
    if self.args.plan:
      # Rows whose previous token equals dottok advance their plan cursor
      # and switch the mask to the next planned index.
      schange = op==self.args.dottok
      changed = schange.nonzero()
      if changed.size(0)>0:
        print(schange, planplace, sorder)
        planplace[changed.squeeze()]+=1
        for j in changed.squeeze(1):
          if planplace[j]<len(sorder[j]):
            mask[j] = 0
            mask[j][0][sorder[j][planplace[j]]]=1
    op = self.emb(op).squeeze(1)
    # LSTM input: previous attention context concatenated with the
    # embedding of the previous token.
    prev = torch.cat((a,op),1)
    hx,cx = self.lstm(prev,(hx,cx))
    a = self.attn(hx.unsqueeze(1),keys,mask=mask).squeeze(1)
    if self.args.title:
      a2 = self.attn2(hx.unsqueeze(1),tencs,mask=tmask).squeeze(1)
      a = torch.cat((a,a2),1)
    l = torch.cat((hx,a),1).unsqueeze(1)
    # s is a gate in (0,1), not a loss: it weights the self.out
    # distribution against the entity scores from self.mattn.
    s = torch.sigmoid(self.switch(l))
    o = self.out(l)
    # Normalise the self.out scores over dim 2.
    o = torch.softmax(o,2)
    o = s*o
    # z is used exactly as returned by self.mattn; no softmax is applied
    # to it here.
    _, z = self.mattn(l,(ents,entlens))
    z = (1-s)*z
    # Joint scores: gated self.out entries followed by gated entity entries.
    o = torch.cat((o,z),2)
    # Suppress entries 0 and 1 of the last dimension.
    # NOTE(review): presumably reserved tokens -- confirm against the vocab.
    o[:,:,0].fill_(0)
    o[:,:,1].fill_(0)
    # Small epsilon keeps log() finite for zero entries.
    o = o+(1e-6*torch.ones_like(o))
    decoded = o.log()
    scores, words = decoded.topk(dim=2,k=k)
    if not beam:
      # First step: create the beam, then tile every per-row tensor so it
      # has one copy per beam entry.
      beam = Beam(words.squeeze(),scores.squeeze(),[hx for i in range(beamsz)],
                [cx for i in range(beamsz)],[a for i in range(beamsz)],beamsz,k,self.args.ntoks)
      beam.endtok = self.endtok
      beam.eostok = self.eostok
      nbeam = len(beam.beam)
      keys = keys.repeat(nbeam,1,1)
      mask = mask.repeat(nbeam,1,1)
      if self.args.title:
        tencs = tencs.repeat(nbeam,1,1)
        tmask = tmask.repeat(nbeam,1,1)
      if self.args.plan:
        planplace = planplace.unsqueeze(0).repeat(nbeam,1)
        sorder = sorder*nbeam
      ents = ents.repeat(nbeam,1,1)
      entlens = entlens.repeat(nbeam)
    else:
      # Later steps: advance the beam; stop when update() returns a falsy
      # value, otherwise truncate the tiled tensors to the new beam size.
      if not beam.update(scores,words,hx,cx,a):
        break
      nbeam = len(beam.beam)
      keys = keys[:nbeam]
      mask = mask[:nbeam]
      if self.args.title:
        tencs = tencs[:nbeam]
        tmask = tmask[:nbeam]
      if self.args.plan:
        planplace = planplace[:nbeam]
        # NOTE(review): sorder[0]*nbeam concatenates the first segment
        # nbeam times into one flat list of ints, unlike the first-step
        # branch, which repeats the list of segments. A later
        # len(sorder[j]) would then receive an int. This looks
        # unintended -- confirm the intended shape before changing it.
        sorder = sorder[0]*nbeam
      ents = ents[:nbeam]
      entlens = entlens[:nbeam]
    # Feed the beam's current words and states into the next step.
    outp = beam.getwords()
    hx = beam.geth()
    cx = beam.getc()
    a = beam.getlast()
  return beam
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值