Transformer-XL Model: Core Source Code Reading Notes


Adaptive softmax implementation

# Adaptive softmax:
# the vocabulary is split into clusters by frequency,
# and each cluster has an embedding matrix of a different dimension.
# Within each cluster, the hidden state is first projected (from d_proj) down to that
# cluster's embedding dimension, then mapped through the cluster's embedding matrix
# to produce logits (unnormalized scores) over the cluster's tokens, from which the
# log-probabilities are obtained.
class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token  # vocabulary size
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1  # number of tail clusters
        self.head_size = self.shortlist_size + self.n_clusters  # size of the head (root) softmax

        if self.n_clusters > 0:  # parameters for the cluster entries in the head softmax
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()
        # each cluster has a linear projection (d_proj -> the cluster's embedding dimension)
        # followed by an output layer that maps onto that cluster's slice of the vocabulary
        if div_val == 1:  # every cluster shares the same embedding dimension d_embed
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)  # each cluster has its own projection matrix; this is its embedding dimension

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )
                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

        self.keep_order = keep_order
    # project hidden with proj so its dimension matches weight,
    # then apply (weight, bias) to obtain the prediction logits
    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj], where d_proj = d_model (the model / embedding dimension)
            target :: [len*bsz]
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:  # plain full softmax
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            # log_softmax, then gather the log-probability of each target word
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:  # every cluster has the same projection dimension
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:  # append the cluster rows to the head weight/bias
                    weight_i = torch.cat(
                        [weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat(
                        [bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)
            # head (root) prediction
            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)  # [len*bsz x head_size]
            # vector that collects the log-probability of each target word
            nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]  # lower/upper bound of the cluster
                # indices of the samples whose target falls into this cluster
                mask_i = (target >= l_idx) & (target < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:  # no target falls into this cluster
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)  # [number of targets in this cluster x head_size]

                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)  # target is in the head: read its log-probability directly
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    hidden_i = hidden.index_select(0, indices_i)

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)
                    # log-probability of the cluster (column -i of the head) plus the within-cluster log-probability
                    logprob_i = head_logprob_i[:, -i] \
                              + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)

                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                    nll.index_copy_(0, indices_i, -logprob_i)  # keep the original token order
                else:
                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

        return nll
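
A minimal usage sketch of the class above (toy sizes of my own choosing, not from the original repo): with n_token = 1000 and cutoffs = [500], tokens 0-499 plus one cluster token form the head, and tokens 500-999 form a tail cluster whose embedding dimension is halved because div_val = 2. Note that out_projs and cluster_weight are created with torch.Tensor, i.e. uninitialized, so they have to be initialized before use (the original training script does this separately).

import torch
import torch.nn as nn

n_token, d_embed, d_proj = 1000, 64, 64
crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_proj, cutoffs=[500], div_val=2)
for p in crit.parameters():
    nn.init.normal_(p, 0.0, 0.02)  # the projection / cluster tensors start uninitialized

hidden = torch.randn(8, d_proj)                              # [len*bsz, d_proj]
target = torch.tensor([3, 10, 499, 42, 600, 750, 999, 512])  # targets in both clusters
nll = crit(hidden, target)                                   # per-token negative log-likelihood
print(nll.shape)                                             # torch.Size([8])

The returned nll already combines the cluster log-probability (column -i of the head) with the within-cluster log-probability, so exp(-nll) is the full probability of each target token.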

Adaptive embedding implementation

# Adaptive Input (adaptive embedding) class
class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, sample_softmax=False):
        """
        :param n_token: vocabulary size
        :param d_embed: dimension of the head embedding matrix
        :param d_proj: dimension of the final output word embedding
        :param cutoffs: list of vocabulary cut points
        :param div_val: the divisor k that shrinks the embedding dimension per cluster
        :param sample_softmax: whether sampled softmax is used (makes the embedding sparse)
        """
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]  # list of vocabulary cut points
        self.div_val = div_val  # the divisor k
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5  # scaling factor applied to the embedding values

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:    # a single embedding table for the whole vocabulary
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        '''
        :param inp: [sentence_len,batch_size]=[36,4]
        :return: [sentence_len,batch_size,embedding_dim]=[36,4,200]
        '''
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed  = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)  # flatten the input
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                dtype=param.dtype, device=param.device)  # output buffer
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()  # indices of the tokens that fall into this cluster

                if indices_i.numel() == 0:
                    continue
                # convert global token ids to local ids within the cluster
                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)  # write the embeddings back into emb_flat by index

            embed = emb_flat.view(*inp.size(), self.d_proj)  # restore the original shape

        embed.mul_(self.emb_scale)  # multiply every element by self.emb_scale

        return embed
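
A matching toy sketch for the embedding side (illustrative sizes only): with div_val = 2 the first 500 tokens get 64-dimensional vectors and the remaining 500 get 32-dimensional ones, and both are projected back to d_proj = 64 before being scaled by d_proj ** 0.5.

import torch
import torch.nn as nn

n_token, d_embed, d_proj = 1000, 64, 64
emb = AdaptiveEmbedding(n_token, d_embed, d_proj, cutoffs=[500], div_val=2)
for p in emb.parameters():
    nn.init.normal_(p, 0.0, 0.02)  # emb_projs are created uninitialized

inp = torch.randint(0, n_token, (36, 4))  # [sentence_len, batch_size]
out = emb(inp)
print(out.shape)                          # torch.Size([36, 4, 64]) = [len, bsz, d_proj]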

Positional Embedding implementation

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb  # d_model

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))  # denominator sequence of the sinusoid formula
        self.register_buffer('inv_freq', inv_freq)  # register a persistent buffer on the module

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)  # outer product: one row per position
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        # concatenation (slightly different from the paper: all sin terms first, then all cos terms)

        if bsz is not None:
            return pos_emb[:,None,:].expand(-1, bsz, -1)
        else:
            return pos_emb[:,None,:]  # add a batch dimension
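
A quick shape check (toy values): the relative positions are fed in descending order klen-1, ..., 0, exactly as MemTransformerLM._forward builds pos_seq further down.

import torch

pe = PositionalEmbedding(demb=200)
pos_seq = torch.arange(35, -1, -1.0)  # klen-1, ..., 0 for klen = 36
pos_emb = pe(pos_seq, bsz=4)
print(pos_emb.shape)                  # torch.Size([36, 4, 200])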

Build the model implementation

# feed-forward network (FFN) sublayer
class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner  # hidden dimension used inside the FFN
        self.dropout = dropout
        # the FFN itself
        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output
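
The FFN sublayer keeps the [qlen, bsz, d_model] shape; with pre_lnorm=True the LayerNorm is applied to the sublayer input, otherwise to the residual sum. A toy shape check:

import torch

ff = PositionwiseFF(d_model=200, d_inner=400, dropout=0.1, pre_lnorm=False)
x = torch.randn(36, 4, 200)
print(ff(x).shape)  # torch.Size([36, 4, 200])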

# parent class of the multi-head attention layers
class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout
        # linear map that produces the Q, K, V vectors
        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)  # dropout on the sublayer output
        self.dropatt = nn.Dropout(dropatt)  # dropout on the attention probabilities
        # output projection: merges the heads and maps back to d_model
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        # layer normalization
        self.layer_norm = nn.LayerNorm(d_model)
        # scaling factor for the attention scores
        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm  # apply LayerNorm to the input instead of the output

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                    device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:,:,None,None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x
    # convert absolute to relative positions: pad a zero column in front, reshape,
    # and drop the first row, which shifts each row into relative-position alignment
    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError

# multi-head attention with relative positional encoding
class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
        # linear projection applied to the relative position embeddings (produces the position keys)
        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        '''
        :param w: word embeddings, [sentence_len, batch_size, d_model]
        :param r: position embeddings, [klen, 1, d_model]
        :param r_w_bias: [n_head, d_head]
        :param r_r_bias: [n_head, d_head]
        :param attn_mask:
        :param mems: [mem_len, batch_size, d_model]
        :return:
        '''
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)  # concatenate the cached previous segment with the current segment
            if self.pre_lnorm:  # produce the Q, K, V vectors
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)  # linear map of the position embeddings
            # split into Q, K, V
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]  # keep only the queries of the current segment
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)                # rlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias     # r_w_bias is the u vector in the paper    # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))             # qlen x klen x bsz x n_head
        # einsum: Einstein summation convention
        rr_head_q = w_head_q + r_r_bias     # r_r_bias is the v vector in the paper
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))              # qlen x rlen x bsz x n_head
        BD = self._rel_shift(BD)  # convert to relative positions

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)  # scale the scores

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)  # positions where attn_mask == 1 are filled with -inf

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)  # turn the scores into probabilities
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head] concatenate the outputs of all heads
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output
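
To see what _rel_shift actually does, here is a small hand-made example (toy numbers, input shape [qlen=3, klen=4, 1, 1]). Padding a zero column in front, viewing the buffer as [klen+1, qlen, ...] and dropping the first row shifts row i to the left by qlen-1-i positions; the entries that wrap around on the right are removed later by the attention mask.

import torch

x = torch.arange(12, dtype=torch.float).view(3, 4, 1, 1)  # rows: [0..3], [4..7], [8..11]
attn = RelPartialLearnableMultiHeadAttn(n_head=1, d_model=4, d_head=4, dropout=0.0)
print(attn._rel_shift(x)[:, :, 0, 0])
# tensor([[ 2.,  3.,  0.,  4.],   row 0 shifted left by 2 (trailing entries are junk, masked later)
#         [ 5.,  6.,  7.,  0.],   row 1 shifted left by 1
#         [ 8.,  9., 10., 11.]])  row 2 unchanged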

# the default decoder block of the language model
class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,**kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()
        # the multi-head attention sublayer
        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,d_head, dropout, **kwargs)

        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
        """
        :param dec_inp: word embeddings, [sentence_len, batch_size, d_model]
        :param r: position embeddings, [klen, 1, d_model]
        :param r_w_bias:[n_head,d_head]
        :param r_r_bias:[n_head,d_head]
        :param dec_attn_mask:
        :param mems:
        :return:
        """
        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,attn_mask=dec_attn_mask,mems=mems)
        output = self.pos_ff(output)

        return output



###############################################################################
# Build the model(构建模型)
###############################################################################
class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None, 
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None, 
                 cutoffs=[], clamp_len=-1,same_length=False):
        """
        :param n_token: vocabulary size
        :param n_layer: number of transformer layers
        :param n_head: number of self-attention heads
        :param d_model: model dimension
        :param d_head: dimension of each head
        :param d_inner: hidden dimension inside the FFN
        :param dropout: global dropout rate
        :param dropatt: attention probability dropout rate
        :param tie_weight: whether to tie the word embedding and softmax weights
        :param d_embed: embedding dimension
        :param div_val: the divisor k of the adaptive input and softmax
        :param tie_projs: a list of booleans indicating whether to tie the projection matrices of the embedding and the softmax
        :param pre_lnorm: apply LayerNorm to the input instead of the output
        :param tgt_len: number of tokens to predict (length of the current segment)
        :param ext_len: length of the extended context (the ext_len tokens preceding the current segment)
        :param mem_len: length of the retained previous hidden states (the memory)
        :param cutoffs: vocabulary cut points for the adaptive input and softmax
        :param clamp_len: use the same pos embeddings after clamp_len
        :param same_length: use the same attn length for all tokens
        """
        super(MemTransformerLM, self).__init__()  # initialize the parent class
        self.n_token = n_token  # vocabulary size

        d_embed = d_model if d_embed is None else d_embed  # embedding dimension
        self.d_embed = d_embed
        self.d_model = d_model  # model dimension
        self.n_head = n_head    # number of self-attention heads
        self.d_head = d_head    # dimension of each head
        # adaptive embedding layer
        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        # the Transformer-XL language model only needs the decoder part of the transformer;
        # build the ModuleList of decoder layers
        self.layers = nn.ModuleList()
        # the default attention
        for i in range(n_layer):
            self.layers.append(
                RelPartialLearnableDecoderLayer(
                    n_head, d_model, d_head, d_inner, dropout,
                    tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

        # use adaptive softmax (including standard softmax)
        self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, cutoffs, div_val=div_val)

        if tie_weight:  # share the weight matrices between the embedding and softmax layers
            for i in range(len(self.crit.out_layers)):
                self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
        # optionally share the projection matrices between the embedding and softmax layers
        if tie_projs:
            for i, tie_proj in enumerate(tie_projs):
                if tie_proj and div_val == 1 and d_model != d_embed:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                elif tie_proj and div_val != 1:
                    self.crit.out_projs[i] = self.word_emb.emb_projs[i]
        self.same_length = same_length
        self.clamp_len = clamp_len
        self._create_params()

    # create the positional embedding and the global bias parameters
    def _create_params(self):
        # default attention
        self.pos_emb = PositionalEmbedding(self.d_model)  # positional embedding
        self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))  # the u vector in the paper
        self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))  # the v vector in the paper

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    # initialize the cached context (mems) for layers 0..n_layer
    def init_mems(self):
        if self.mem_len > 0:  # e.g. mem_len = 36
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None
    # update mems; the memory length is always kept at mem_len
    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'
        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)  # exclude the hidden states that will serve as the extended context of the next step
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        # dec_inp = [segment length qlen, batch_size] = [36, 4]
        qlen, bsz = dec_inp.size()
        word_emb = self.word_emb(dec_inp)
        # word_emb = [qlen, batch_size, embedding dimension] = [36, 4, 200]

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen   # total length including the cached context
        if self.same_length:  # every token attends over a window of the same fixed length
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                    + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
        # dec_attn_mask = [qlen, klen, 1]: upper-triangular mask where 1 marks a position to be
        # masked, with a trailing dimension added (a standalone toy illustration follows after this class)

        hids = []  # records the input of every layer

        # build the position embeddings
        # pos_seq = [klen-1, ..., 0]
        pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
                               dtype=word_emb.dtype)
        if self.clamp_len > 0:  # positions beyond clamp_len reuse the same position embedding
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.pos_emb(pos_seq)  # pos_emb = [klen, 1, d_model]

        core_out = self.drop(word_emb)
        pos_emb = self.drop(pos_emb)

        hids.append(core_out)  # store the bottom-layer input in hids

        # decoder
        for i, layer in enumerate(self.layers):
            mems_i = None if mems is None else mems[i]
            core_out = layer(core_out, pos_emb, self.r_w_bias,
                    self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
            hids.append(core_out)


        core_out = self.drop(core_out)  # hidden-state output of the top layer

        # note: the arguments are passed as (mlen, qlen) although the signature reads (qlen, mlen);
        # with ext_len = 0 this makes no difference, since end_idx = mlen + qlen either way
        new_mems = self._update_mems(hids, mems, mlen, qlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        '''
        :param data: input tokens, [segment length qlen, batch_size] = [36, 4]
        :param target: [segment length qlen, batch_size] = [36, 4]
        :param mems: empty, or a list of n_layer+1 tensors, each of shape [mem_len, batch_size, d_model] = [36, 4, 200]
        :return:
        '''

        if not mems: mems = self.init_mems()  # initialize mems

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)
        #hidden=[36,4,200]
        pred_hid = hidden[-tgt_len:]

        loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
        loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems
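
Here is the standalone toy illustration of the causal mask referred to in _forward above (standard same_length=False branch): with qlen=4 current tokens and mlen=2 cached tokens, query i may attend to all mlen memory slots and to positions up to and including itself in the current segment; 1 marks a masked position.

import torch

qlen, mlen = 4, 2
klen = qlen + mlen
dec_attn_mask = torch.triu(torch.ones(qlen, klen), diagonal=1 + mlen).byte()[:, :, None]
print(dec_attn_mask[:, :, 0])
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]], dtype=torch.uint8)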

Test code

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='unit test')

    parser.add_argument('--n_layer', type=int, default=4, help='')
    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
    parser.add_argument('--n_head', type=int, default=2, help='')
    parser.add_argument('--d_head', type=int, default=2, help='')
    parser.add_argument('--d_model', type=int, default=200, help='')
    parser.add_argument('--d_embed', type=int, default=200, help='')
    parser.add_argument('--d_inner', type=int, default=200, help='')
    parser.add_argument('--dropout', type=float, default=0.0, help='')
    parser.add_argument('--cuda', action='store_true', help='')
    parser.add_argument('--seed', type=int, default=1111, help='')
    parser.add_argument('--multi_gpu', action='store_true', help='')

    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")

    B = 4  # batch_size
    # tgt_len: number of tokens to predict (length of the current segment)
    # mem_len: length of the retained previous hidden states
    # ext_len: length of the extended context (extra tokens prepended to each segment)
    tgt_len, mem_len, ext_len = 36, 36, 0
    data_len = tgt_len * 20
    args.n_token = 10000

    import data_utils

    data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)#data=[2880]
    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)


    cutoffs = [args.n_token // 2]
    tie_projs = [False] + [True] * len(cutoffs)

    for div_val in [1, 2]:
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                            args.d_model, args.d_head, args.d_inner, args.dropout,
                            dropatt=args.dropout, tie_weight=True, 
                            d_embed=d_embed, div_val=div_val, 
                            tie_projs=tie_projs, pre_lnorm=True,
                            tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, 
                            cutoffs=cutoffs).to(device)

            print(sum(p.numel() for p in model.parameters()))
            '''
            Data note:
                data is reshaped to [720, 4]: each column is one token stream (batch_size = 4),
                and every 36 tokens along a stream form one training sample for the language model.
            '''
            mems = tuple()
            for idx, (inp, tgt, seqlen) in enumerate(diter):
                print('batch {}'.format(idx))
                out = model(inp, tgt, *mems)
                mems = out[1:]