2021SC@SDUSC
目录
介绍完 newmodel.py 中 model 类的一部分,接下来开始分析 model 类的剩余部分。
model类的剩余部分
def beam_generate(self, b, beamsz, k):
    # Decode one batch with beam search and return the finished Beam object.
    #
    # Args:
    #   b:      batch object; this method reads b.src, b.ent, b.rel and b.nerd.
    #   beamsz: number of hypotheses kept on the beam.
    #   k:      top-k width used when expanding each hypothesis.
    #
    # NOTE(review): mirrors the training forward pass but drives a Beam object
    # instead of teacher forcing; states are .repeat()-ed across the beam after
    # the first step, which suggests batch size 1 is assumed -- TODO confirm
    # with callers.
    if self.args.title:
        # Encode the title and build its padding mask for title attention.
        tencs, _ = self.tenc(b.src)
        tmask = self.maskFromList(tencs.size(), b.src[1]).unsqueeze(1)
    ents = b.ent
    entlens = ents[2]
    ents = self.le(ents)  # entity encoder
    if self.graph:
        # Graph encoder supplies the initial hidden state (glob) and the
        # relation keys/mask used by the decoder attention.
        gents, glob, grels = self.ge(b.rel[0], b.rel[1], (ents, entlens))
        hx = glob
        keys, mask = grels
        mask = mask == 0  # invert so True marks positions to mask out
    else:
        # No graph: attend over the entity encodings directly.
        mask = self.maskFromList(ents.size(), entlens)
        hx = ents.max(dim=1)[0]  # max-pool entities for the initial state
        keys = ents
    mask = mask.unsqueeze(1)  # add a singleton query dim for attention
    if self.args.plan:
        # Decode a content plan: a "-1"-separated sequence of entity indices,
        # one group per output sentence.
        planlogits = self.splan.plan_decode(hx, keys, mask.clone(), entlens)
        print(planlogits.size())
        sorder = ' '.join([str(x) for x in planlogits.max(1)[1][0].tolist()])
        print(sorder)
        sorder = [x.strip() for x in sorder.split("-1")]
        sorder = [[int(y) for y in x.strip().split(" ")] for x in sorder]
        # Restrict attention to the first planned entity of each group;
        # planplace tracks how far into its plan each row has advanced.
        mask.fill_(0)
        planplace = torch.zeros(hx.size(0)).long()
        for i, m in enumerate(sorder):
            mask[i][0][m[0]] = 1
    else:
        planlogits = None

    # NOTE(review): torch.tensor(hx) copies hx to seed the LSTM cell state;
    # hx.clone().detach() is the non-deprecated spelling of the same thing.
    cx = torch.tensor(hx)
    a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
    if self.args.title:
        # Concatenate title context onto the graph/entity context vector.
        a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
        a = torch.cat((a, a2), 1)
    outputs = []  # NOTE(review): never used below -- dead local?
    # Start every hypothesis from the start-of-sequence token.
    outp = torch.LongTensor(ents.size(0), 1).fill_(self.starttok).cuda()
    beam = None
    for i in range(self.maxlen):
        # Map vertex placeholders in the previous output to embedding ids.
        op = self.emb_w_vertex(outp.clone(), b.nerd)
        if self.args.plan:
            # Emitting the sentence-delimiter token advances that hypothesis
            # to the next planned entity and re-points its attention mask.
            schange = op == self.args.dottok
            if schange.nonzero().size(0) > 0:
                print(schange, planplace, sorder)
                planplace[schange.nonzero().squeeze()] += 1
                for j in schange.nonzero().squeeze(1):
                    if planplace[j] < len(sorder[j]):
                        mask[j] = 0
                        m = sorder[j][planplace[j]]
                        mask[j][0][sorder[j][planplace[j]]] = 1
        op = self.emb(op).squeeze(1)
        prev = torch.cat((a, op), 1)  # input feeding: prev context + embedding
        hx, cx = self.lstm(prev, (hx, cx))
        a = self.attn(hx.unsqueeze(1), keys, mask=mask).squeeze(1)
        if self.args.title:
            a2 = self.attn2(hx.unsqueeze(1), tencs, mask=tmask).squeeze(1)
            a = torch.cat((a, a2), 1)
        l = torch.cat((hx, a), 1).unsqueeze(1)
        # s is the copy-switch gate: how much probability mass goes to the
        # vocabulary distribution (o) versus the entity copy distribution (z).
        s = torch.sigmoid(self.switch(l))
        o = self.out(l)
        o = torch.softmax(o, 2)  # softmax over the vocabulary dim
        o = s * o
        _, z = self.mattn(l, (ents, entlens))
        #z = torch.softmax(z,2)
        z = (1 - s) * z
        o = torch.cat((o, z), 2)  # joint distribution: vocab ++ copy slots
        # Zero out token ids 0 and 1 (presumably pad/unk -- TODO confirm
        # against the vocabulary) so they can never be generated.
        o[:, :, 0].fill_(0)
        o[:, :, 1].fill_(0)
        o = o + (1e-6 * torch.ones_like(o))  # floor to avoid log(0)
        decoded = o.log()
        scores, words = decoded.topk(dim=2, k=k)
        if not beam:
            # First step: create the beam and tile every decoder-side tensor
            # so each hypothesis carries its own copy of the state.
            beam = Beam(words.squeeze(), scores.squeeze(),
                        [hx for i in range(beamsz)],
                        [cx for i in range(beamsz)],
                        [a for i in range(beamsz)],
                        beamsz, k, self.args.ntoks)
            beam.endtok = self.endtok
            beam.eostok = self.eostok
            keys = keys.repeat(len(beam.beam), 1, 1)
            mask = mask.repeat(len(beam.beam), 1, 1)
            if self.args.title:
                tencs = tencs.repeat(len(beam.beam), 1, 1)
                tmask = tmask.repeat(len(beam.beam), 1, 1)
            if self.args.plan:
                planplace = planplace.unsqueeze(0).repeat(len(beam.beam), 1)
                sorder = sorder * len(beam.beam)
            ents = ents.repeat(len(beam.beam), 1, 1)
            entlens = entlens.repeat(len(beam.beam))
        else:
            # Later steps: advance the beam; update() returning falsy means
            # every hypothesis has finished, so stop decoding.
            if not beam.update(scores, words, hx, cx, a):
                break
            # The beam may have shrunk; trim the tiled tensors to match.
            keys = keys[:len(beam.beam)]
            mask = mask[:len(beam.beam)]
            if self.args.title:
                tencs = tencs[:len(beam.beam)]
                tmask = tmask[:len(beam.beam)]
            if self.args.plan:
                planplace = planplace[:len(beam.beam)]
                sorder = sorder[0] * len(beam.beam)
            ents = ents[:len(beam.beam)]
            entlens = entlens[:len(beam.beam)]
        # Pull next inputs and states back out of the surviving hypotheses.
        outp = beam.getwords()
        hx = beam.geth()
        cx = beam.getc()
        a = beam.getlast()
    return beam