2021SC@SDUSC
The previous post covered the list_encode, lseq_encode, and graph_encode classes. This post looks at the Beam class and the splanner class.
Introduction to the Beam class
The Beam class implements beam search for decoding. It keeps `beamsz` live hypotheses, each stored as a `beam_obj` holding the word sequence, accumulated score, LSTM states `h`/`c`, and the last context vector; finished hypotheses are moved into `self.done`. Note that `self.endtok` and `self.eostok` are used in `update` but never assigned in `__init__`, so they are evidently set on the object from outside the class.

```python
class Beam():
    def __init__(self, words, scores, hs, cs, last, beamsz, k, vsz):
        self.beamsz = beamsz   # number of hypotheses kept alive
        self.k = k             # candidates expanded per hypothesis
        self.beam = []
        for i in range(beamsz):
            # each hypothesis is a beam_obj: (word, score, h state, c state, last context)
            self.beam.append(beam_obj(words[i].item(), scores[i].item(),
                                      hs[i], cs[i], last[i]))
        self.done = []         # finished hypotheses
        self.vocabsz = vsz

    def sort(self, norm=True):
        # pad the finished list with live hypotheses if needed, then keep beamsz of them
        if len(self.done) < self.beamsz:
            self.done.extend(self.beam)
        self.done = self.done[:self.beamsz]
        if norm:
            # sort by length-normalized score, best first
            self.done = sorted(self.done, key=lambda x: x.score / len(x.words), reverse=True)
        else:
            self.done = sorted(self.done, key=lambda x: x.score, reverse=True)

    def dup_obj(self, obj):
        # copy a hypothesis: the lists are copied, the tensors are shared
        new_obj = beam_obj(None, None, None, None, None)
        new_obj.words = [x for x in obj.words]
        new_obj.score = obj.score
        new_obj.prevent = obj.prevent
        new_obj.firstwords = [x for x in obj.firstwords]
        new_obj.isstart = obj.isstart
        return new_obj

    def getwords(self):
        # last word of every live hypothesis, as a (beamsz, 1) LongTensor
        return tt.LongTensor([[x.words[-1]] for x in self.beam])

    def geth(self):
        # concatenate the hidden states of all hypotheses along dim 0
        return torch.cat([x.h for x in self.beam], dim=0)

    def getc(self):
        return torch.cat([x.c for x in self.beam], dim=0)

    def getlast(self):
        return torch.cat([x.last for x in self.beam], dim=0)

    def getscores(self):
        # accumulated scores tiled k times: a (beamsz, k) FloatTensor
        return tt.FloatTensor([[x.score] for x in self.beam]).repeat(1, self.k)

    def getPrevEnt(self):
        return [x.prevent for x in self.beam]

    def getIsStart(self):
        return [(i, self.beam[i].firstwords)
                for i in range(len(self.beam)) if self.beam[i].isstart]

    def update(self, scores, words, hs, cs, lasts):
        beam = self.beam
        scores = scores.squeeze()   # drop size-1 dimensions
        words = words.squeeze()
        k = self.k
        gotscores = self.getscores()  # (unused)
        # add each hypothesis's accumulated score to its k candidates,
        # then keep the best len(self.beam) candidates overall
        scores = scores + self.getscores()
        scores, idx = scores.view(-1).topk(len(self.beam))
        newbeam = []
        for i, x in enumerate(idx):
            x = x.item()
            r = x // k   # parent hypothesis index
            c = x % k    # candidate index within that hypothesis (unused)
            w = words.view(-1)[x].item()
            new_obj = self.dup_obj(beam[r])
            if w == self.endtok:
                # hypothesis finished: move it to done
                new_obj.score = scores[i]
                self.done.append(new_obj)
            else:
                if new_obj.isstart:
                    new_obj.isstart = False
                    new_obj.firstwords.append(w)
                if w >= self.vocabsz:
                    # indices past the vocabulary correspond to copied entities
                    new_obj.prevent = w
                if w == self.eostok:
                    new_obj.isstart = True
                new_obj.words.append(w)
                new_obj.score = scores[i]
                new_obj.h = hs[r, :].unsqueeze(0)     # re-add the batch dimension
                new_obj.c = cs[r, :].unsqueeze(0)
                new_obj.last = lasts[r, :].unsqueeze(0)
                newbeam.append(new_obj)
        self.beam = newbeam
        return newbeam != []   # False once every hypothesis has finished
```
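To make the expansion step in `update` easier to follow, here is a minimal, self-contained sketch of the same add-scores, flatten, `topk` pattern. The score values, `beamsz`, and `k` below are made up for illustration and are not taken from the repository:

```python
import torch

beamsz, k = 3, 4
# hypothetical accumulated scores of the 3 live hypotheses, shape (beamsz, 1)
beam_scores = torch.tensor([[-0.1], [-0.5], [-0.9]])
# hypothetical log-probabilities of the k candidate words per hypothesis
step_scores = torch.log(torch.rand(beamsz, k))

# same pattern as Beam.update: add accumulated scores, flatten, keep the top beamsz
total = step_scores + beam_scores.repeat(1, k)   # (beamsz, k), like getscores()
scores, idx = total.view(-1).topk(beamsz)

for score, x in zip(scores.tolist(), idx.tolist()):
    r, c = x // k, x % k   # r: parent hypothesis, c: candidate within it
    print(f"keep candidate {c} of hypothesis {r}, score {score:.3f}")
```

Flattening first and taking a single global `topk` is what lets a strong hypothesis claim more than one slot in the next beam.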
Introduction to the splanner class
The splanner class appears to be a small planning module: a GRU cell repeatedly attends over a set of keys (three learned special embeddings concatenated with the projected entity keys) and selects one key per step. `plan_decode` decodes a plan greedily at inference time, while `forward` teacher-forces against the `gold` plan. The `print` calls and the `exit()` in `plan_decode` look like leftover debugging statements.

```python
class splanner(nn.Module):
    def __init__(self, args):
        super().__init__()   # initialize the nn.Module machinery
        asz = 50
        # a (1, 3, asz) trainable tensor of three special plan embeddings;
        # nn.Parameter registers it as a trainable parameter of the module
        self.emb = nn.Parameter(torch.zeros(1, 3, asz))
        nn.init.xavier_normal_(self.emb)
        # GRUCell(input_size, hidden_size)
        self.gru = nn.GRUCell(asz, asz)
        # project the decoder state and the keys from args.hsz down to asz
        self.clin = nn.Linear(args.hsz, asz)
        self.klin = nn.Linear(args.hsz, asz)

    def attend(self, dec, emb, emask):
        dec = dec.unsqueeze(1)                        # (batch, 1, asz)
        unnorm = torch.bmm(dec, emb.transpose(1, 2))  # dot-product scores over the keys
        unnorm.masked_fill_(emask, -float('inf'))     # masked keys get zero weight
        attn = F.softmax(unnorm, dim=2)               # normalize over the key dimension
        return attn

    def plan_decode(self, hx, keys, mask, entlens):
        # greedy decoding of a plan for a single example
        entlens = entlens[0]
        e = self.emb
        hx = self.clin(hx)
        keys = self.klin(keys)
        keysleft = keys.size(1)
        print(keysleft)
        keys = torch.cat((e, keys), 1)   # prepend the three special embeddings
        unmask = torch.zeros(hx.size(0), 1, 3).byte().cuda()
        print(mask)
        mask = torch.cat((unmask, mask), 2)
        print(mask)
        ops = []
        prev = keys[:, 1, :]
        while keysleft > 1:
            hx = self.gru(prev, hx)
            print(hx.size(), keys.size())
            a = self.attend(hx, keys, mask)
            print(a)
            sel = a.max(2)[1].squeeze()   # index of the most-attended key
            print(sel)
            ops.append(keys[:, sel])
            if sel > 2 and sel != entlens:
                # an ordinary entity was chosen: mask it so it cannot be reused
                mask[0, 0, sel] = 1
                keysleft -= 1
            if sel <= 2:
                mask[0, 0, sel] = 1
            else:
                mask[0, 0, :2] = 0
            if sel == entlens:
                mask[0, 0, entlens] = 1
            else:
                mask[0, 0, entlens] = 0
            prev = keys[:, sel]
        ops = torch.cat(ops, 1)
        exit()   # debugging leftover: aborts before returning
        return ops

    def forward(self, hx, keys, mask, entlens, gold=None):
        # teacher-forced training pass over the gold plan
        e = self.emb.repeat(hx.size(0), 1, 1)   # tile the special embeddings per example
        hx = self.clin(hx)
        keys = self.klin(keys)
        gold = gold[0]
        keys = torch.cat((e, keys), 1)
        # per-example offsets for indexing into the flattened (batch*keys) view;
        # torch.arange excludes its end value
        gscaler = torch.arange(hx.size(0)).long().cuda() * keys.size(1)
        unmask = torch.zeros(hx.size(0), 1, 3).byte().cuda()
        mask = torch.cat((unmask, mask), 2)
        ops = []
        goldup = gold.masked_fill(gold < 3, 0)   # special tokens are never masked out
        for i, j in enumerate(entlens):
            goldup[i, j] = 0
        prev = keys[:, 1, :]
        for i in range(gold.size(1)):
            hx = self.gru(prev, hx)
            a = self.attend(hx, keys, mask)
            # mask this step's gold selection so it cannot be picked again
            mask = mask.view(-1).index_fill(0, goldup[:, i] + gscaler, 1).view_as(mask)
            ops.append(a)
            prev = keys.view(-1, keys.size(2))[gscaler]
        ops = torch.cat(ops, 1)
        return ops
```
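The `attend` method is ordinary masked dot-product attention, and its shape logic can be checked in isolation. The sketch below uses invented shapes (batch 2, 5 keys) and a boolean mask rather than the `.byte()` mask of the original (older PyTorch versions accepted byte masks in `masked_fill_`); everything else mirrors `attend` line by line:

```python
import torch
import torch.nn.functional as F

batch, asz, nkeys = 2, 50, 5
dec = torch.randn(batch, asz)           # decoder state, as passed into attend
emb = torch.randn(batch, nkeys, asz)    # projected keys
emask = torch.zeros(batch, 1, nkeys, dtype=torch.bool)
emask[:, :, 3:] = True                  # pretend the last two keys are invalid

dec = dec.unsqueeze(1)                           # (batch, 1, asz)
unnorm = torch.bmm(dec, emb.transpose(1, 2))     # (batch, 1, nkeys) scores
unnorm.masked_fill_(emask, -float('inf'))        # invalid keys scored -inf
attn = F.softmax(unnorm, dim=2)                  # rows sum to 1, masked entries are 0
print(attn)
```

Because masked positions receive `-inf` before the softmax, their attention weight is exactly zero while the remaining weights still sum to one; this is what lets `plan_decode` retire a key simply by setting its mask bit to 1.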