CogView中的Transformer

入门小菜鸟,希望像做笔记记录自己学的东西,也希望能帮助到同样入门的人,更希望大佬们帮忙纠错啦~侵权立删。

目录

一、原理

1、总体介绍

2、具体实现

(1)不采取稀疏处理(默认)

(2)采取稀疏训练

​(3)稀疏推断

二、代码解析

1、__init__

(1)参数设定

(2)存储激活检查点标志

(3)定义输出层初始化方法

(4)Position embedding

(5)窗口定义

(6)Transformer layers设置

(7)将 num_layer 个 transformer layer打包在一起,以列表形式保存

(8)output层的LayerNorm处理

(9)激活点检查

2、forward

(1)获取最终的输入层的相关信息

(2)attention mask建立

(3)稀疏训练or推断准备

(4)对输入层的处理

(5)这次是否有产生记忆模块

(6)获取下一层的输入——分为是否采取检查点激活两种情况来分析

(7)最后一层norm

(8)记忆模块更新

(9)返回这一层的输出结果和记忆模块


一、原理

1、总体介绍

将n个的 transformer blocks 打包在一起,即 n * transformer layer + final layernorm 两部分组成

2、具体实现

(1)不采取稀疏处理(默认)

 (2)采取稀疏训练

 新建的rmask(k为输入的总列数;w为窗口大小;t为调整窗口数量所用)

 (3)稀疏推断


二、代码解析

1、__init__

(1)参数设定

class GPT2ParallelTransformer(torch.nn.Module):
    """GPT-2 transformer.

    This module takes input from embedding layer and it's output can
    be used directly by a logit layer. It consists of L (num-layers)
    blocks of:
        layer norm
        self attention
        residual connection
        layer norm
        mlp
        residual connection
    followed by a final layer norm.

    Arguments:
        num_layers: Number of transformer layers.
        hidden_size: The hidden size of the self attention.
        num_attention_heads: number of attention head in the self
                             attention.
        attention_dropout_prob: dropout probability of the attention
                                score in self attention.
        output_dropout_prob: dropout probability for the outputs
                             after self attention and final output.
        checkpoint_activations: if True, checkpoint activations.
        checkpoint_num_layers: number of layers to checkpoint. This
                               is basically the chunk size in checkpoitning.
        layernorm_epsilon: epsilon used in layernorm to avoid
                           division by zero.
        init_method_std: standard deviation of the init method which has
                         the form N(0, std).
        use_scaled_init_for_output_weights: If Ture use 1/sqrt(2*num_layers)
                                            scaling for the output weights (
                                            output of self attention and mlp).
    """
    def __init__(self,
                 num_layers,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 max_memory_length,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 use_scaled_init_for_output_weights=True,
                 query_window=128,
                 key_window_times=6,
                 num_pivot=768
                 ):
        super(GPT2ParallelTransformer, self).__init__()
  • num_layers:transformer层的数量;
  • hidden_size:自我注意力模块的隐藏大小(嵌入向量的维度);
  • num_attention_heads:自我注意力模块中attention head的数量;
  • max_sequence_length:词典大小;
  • max_memory_length:最大记忆长度;
  • embedding_dropout_prob:嵌入层(该模块的输入部分)中元素被dropout的概率(为了解决过拟合问题而随机丢弃一部分元素);
  • attention_dropout_prob:同样道理,注意力模块中注意力得分被dropout的概率;
  • output_dropout_prob:同理,输出层后的输出被dropout的概率;
  • checkpoint_activations:是否执行检查点激活;
  • checkpoint_num_layers:检查点的层数。这基本上是checkpoitning中的块大小;
  • layernorm_epsilon:在layernform中用于避免被零除的ε(用于防止分母为0);
  • init_method_std:初始化方法(使用让权重呈现正态分布的方法)中正态分布的方差;
  • use_scaled_init_for_output_weights:是否对自注意力和mlp的输出的权重调用scaled_init_method进行初始化;
  • query_window:稀疏处理中的窗口大小;
  • key_window_times:用于调整窗口数量;
  • num_pivot:transformer里图像token和文本token的总和数量

(2)存储激活检查点标志

        # Store activation checkpoiting flag.
        #首先先记录是否执行检查点激活,检查点的层数,最大记忆长度和最大序列长度信息
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.max_memory_length = max_memory_length
        self.max_sequence_length = max_sequence_length

(3)定义输出层初始化方法

由use_scaled_init_for_output_weights决定,若为False则不进行初始化缩放,若为true则调用scaled_init_method进行初始化

        #输出层初始化方法定义——由use_scaled_init_for_output_weights决定,若为False则不进行初始化缩放,若为true则调用scaled_init_method进行初始化
        output_layer_init_method = None
        if use_scaled_init_for_output_weights:
            output_layer_init_method = scaled_init_method(init_method_std,
                                                      num_layers)

scaled_init_method函数——返回初始化方法:初始权重呈均值为0,方差为init_method_std//sqrt(2*num_layers)的正态分布。

def scaled_init_method(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_

(4)Position embedding

先进行嵌入层的dropout(防止过拟合),然后调用torch.nn.Embedding()方法按词典大小max_sequence_length和嵌入向量的维度hidden_size来定义词向量格式,然后将词向量的值初始化为呈以0为均值,以init_method_std为方差的正态分布。

        # Embeddings dropout嵌入层dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        # Position embedding (serial).初始化含位置信息的词向量方法
        self.position_embeddings = torch.nn.Embedding(max_sequence_length,
                                                        hidden_size)#随机以max_sequence_length为词典的大小(词的个数),以hidden_size来嵌入向量的维度(即用多少维来表示一个符号)初始化词向量,默认词向量值在正态分布N(0,1)中随机取值
        # Initialize the position embeddings.词向量值在正态分布N(0,init_method_std)中随机取值
        torch.nn.init.normal_(self.position_embeddings.weight, mean=0.0, std=init_method_std)

(5)窗口定义

        self.query_window = query_window
        self.key_window_times = key_window_times
        self.num_pivot = num_pivot

(6)Transformer layers设置

首先定义了一个get_layer()函数来获得对应层id的网络层(transformer layer)

        #获得对应层id的网络层
        def get_layer(layer_id):
            return GPT2ParallelTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                unscaled_init_method(init_method_std),
                output_layer_init_method=output_layer_init_method,
                query_window=query_window,
                key_window_times=key_window_times,
                scale_normalization=True
                )

这里调用了GPT2ParallelTransformerLayer类

(7)将 num_layer 个 transformer layer打包在一起,以列表形式保存

        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(num_layers)])

(8)output层的LayerNorm处理

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

(9)激活点检查

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint
        self.rmask = None#是否进行稀疏处理

2、forward

    def forward(self, hidden_states, position_ids, attention_mask, txt_indices_bool, img_indices_bool, is_sparse=0, *mems):
        '''''
        hidden_states:输入的网络层;
        position_ids:位置编码;
        attention_mask;
        txt_indices_bool:选取文本token有效的索引矩阵
        img_indices_bool:选取图像token有效的索引矩阵
        is_sparse:是否稀疏处理,稀疏训练,稀疏推断
        mems:记忆模块;
        '''''

(1)获取最终的输入层的相关信息

获取b,s和最终的输入列数(hidden_states和记忆模块的concat的结果)

        batch_size, query_length = hidden_states.size()[:2]#获取batchsize(b)和读取的序列长度(s)
        memory_length = mems[0].size(1) if mems else 0#获取记忆模块的序列长度(模块列数)
        key_length = query_length + memory_length#得到最终的序列长度(类似concat维数增加)

(2)attention mask建立

最终shape[1,1,s,s](无记忆模块情况下,有记忆为[1,1,s,s+m],m为memory_length)

            # conventional transformer
            #建立常规transformer的attention mask
            def build_mask_matrix(query_length, key_length, sep):
                m = torch.ones((1, query_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype)#初始化为全一矩阵
                assert query_length <= key_length
                m[0, :, -query_length:] = torch.tril(m[0, :, -query_length:])#返回m[0, :, -query_length:]区域(最后两维)是下三角矩阵的矩阵
                m[0, :, :sep + (key_length - query_length)] = 1#注意力标记
                m = m.unsqueeze(1)#[1,s,s+m]->[1,1,s,s+m]
                return m
            #生成attention_mask,无记忆模块是[1,1,s,s],有记忆是[1,1,s,s+m]
            attention_mask = build_mask_matrix(query_length, key_length, sep)

(3)稀疏训练or推断准备

✨获取稀疏训练的rmask

        #启用稀疏训练生成rmask
        if is_sparse == 1 and (self.rmask is None):
            w, times = self.query_window, self.key_window_times#滑动窗口大小+窗口数的减少量获取
            g = key_length // w#获取全局attention窗口个数
            tmp = torch.ones((g-times+1, w , w), device=hidden_states.device, dtype=hidden_states.dtype)#初始化rmask(可理解为g-times+1个窗口)
            tmp = torch.tril(1 - torch.block_diag(*tmp))#*将三维矩阵变成二维矩阵列表;torch.block_diag将g-times+1个w*w矩阵组合成一个块对角矩阵,1-使得中间块为0,其余为1;torch.tril返回下三角矩阵。shape为((g-times+1)*w,(g-times+1)*w)
            self.rmask = torch.nn.functional.pad(tmp, (0, (times-1)*w, (times-1)*w, 0)) # pad (left, right, top, bottom),这四个元素的位置代表了填充的位置,大小为填充的行数,默认填0,所以最终shape为(g*w,g*w),左下角为一个((g-times+1)*w,(g-times+1)*w)大小的下三角矩阵

✨获取左边界和支点

        if is_sparse == 2:#稀疏推断
            left_boundary = max(0, key_length - self.key_window_times * self.query_window)#获取左边界(将key_length分为n份query_window的块块,做除法后的余数部分为左边界
            window_idx = torch.arange(left_boundary, key_length, device=hidden_states.device, dtype=torch.long).expand(batch_size, -1)#torch.arange获得[left_boundary,...,key_length-1];expand(batch_size, -1)获得batchsize条[left_boundary,...,key_length-1],获得shape为(batchsize*key_length-left_boundary)
        elif is_sparse == 1:#稀疏训练
            left_boundary = key_length#获取左边界
            num_pivot = self.num_pivot#transformer里图像token和文本token的总和数量获取

✨选取每个batch中对应有效的index的image token和txt token

        #选取每个batch中对应有效的index的image token和txt token
        if is_sparse: # 1 or 2                
            # select out the real indices for sampling
            img_indices = [img_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]#.nonzero(as_tuple=False)取出非0元素的索引(即取出有效索引);.view(-1)将其展平
            txt_indices = [txt_indices_bool[i][:left_boundary].nonzero(as_tuple=False).view(-1) for i in range(batch_size)]

✨稀疏推断支点数目设定

        #稀疏推断支点数目设定(总token数量增加)
        if is_sparse == 2:
            ratio = self.num_pivot / self.max_sequence_length#支点比例获取
            max_text_num = max(len(text_idx) for text_idx in txt_indices)#获取batch中最长的有效文本token长度
            num_pivot = max_text_num + int((left_boundary - max_text_num) * ratio)#支点数目更新

(4)对输入层的处理

给输入层加入初始化的位置信息词向量并且进行dropout操作

        #对输入层的处理
        position_embeddings = self.position_embeddings(position_ids)#对位置信息position_ids进行词向量的初始化
        hidden_states = hidden_states + position_embeddings#输入层加入初始化的位置信息词向量
        hidden_states = self.embedding_dropout(hidden_states)#对输入层进行dropout

(5)这次是否有产生记忆模块

若拥有最大记忆长度,则产生的记忆模块是输入层,但不需要计算其梯度

        #这次是否有产生记忆模块

        if self.max_memory_length > 0:#若拥有最大记忆长度,
            mem_layers = [hidden_states.detach()]#记忆模块赋为输入层,但不需要计算其梯度
        else:#否则没有记忆模块
            mem_layers = []

然后保存一下attention mask

        attention_mask_saved = attention_mask#保存attention mask

(6)获取下一层的输入——分为是否采取检查点激活两种情况来分析

(都要利用get_layer来实现,所以都要先获取相应的参数输入才可调用)

✨采取检查点激活

①首先是必要的初始化和参数获取

            l = 0#初始化start层id
            num_layers = len(self.layers)#Transformer layers的数量获取
            chunk_length = self.checkpoint_num_layers#检查点的层数

循环获取层

            while l < num_layers:

②稀疏训练or推断情况下获取下一层的输入的参数

                if is_sparse > 0:#稀疏训练or推断

🌳获取pivot的索引(pivot即随机抽取的token,用于代表全局整幅图片)

                    # =====================   Pivot Mask   ======================== #
                    pivot_idx = torch.stack([
                        torch.cat((
                            text_idx,
                            img_indices[i][
                                torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
                            ]
                        ), dim=0)
                        for i, text_idx in enumerate(txt_indices)
                    ])
                    #首先由random.sample随机抽取(预设支点数量-该batch的有效文本token长度)=该batch的有效图像token长度个图像token索引,并且将文本token和图像token拼接在一起

🌳然后对于稀疏训练:获取pivot_attention_mask,进而获取输入所需的参数列表

                    if is_sparse == 1: # sparse training
                        assert key_length == query_length#断言最终的序列长度和读取的序列长度(s)是否相同
                        b, s = batch_size, key_length
                        pivot_attention_mask = self.rmask.expand(b, s, s).gather(dim=-1, index=pivot_idx.unsqueeze(1).expand(b, s, self.num_pivot))#生成针对随机选取的token的注意力矩阵——pivot attention mask
                        #expand()函数扩展维度,其余不变。
                        # 相当于先由b个原来的s*s(s=g*w)大小的rmask(即每个batch里都有rmask)拼成一个大小为(b,s,s)的矩阵;再由gather函数根据 index 参数(即是索引)返回矩阵里面对应位置的值(即挑出随机选中的token对应索引值的rmask值)——针对的是每个batch的s*s的rmask;最后再由expand函数展成大小为(b,s,随机选取的token数量)的矩阵
                        args = [hidden_states, pivot_attention_mask, pivot_idx, torch.tensor(is_sparse)]#参数列表记录

🌳然后对于稀疏推理:获取全部需要注意的token的idx,并形成参数列表

                    elif is_sparse == 2: # sparse inference
                        pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)#获取随机选取的token的idx矩阵与额外标记注意的窗口的idx矩阵concat后的需要attention的idx矩阵
                        args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]#参数列表记录

🌳错误提示

                    else:
                        raise NotImplementedError

③非稀疏处理情况下参数列表获取

                else:
                    args = [hidden_states, attention_mask_saved]#非稀疏处理的参数列表记录(输入层和attention mask)

④记忆模块对参数列表的补充

                #对于记忆模块的参数补充
                if mems:
                    args += mems[l: l + chunk_length]

⑤获得下一层的输入并进行检查,且start层idx(l)更新

                #检查点激活并得到下一层的输入层
                hidden_states = checkpoint(custom(l, l + chunk_length), *args)
                #start为第l层,end为第l + chunk_length层(共检查点层数数量)
                l += chunk_length#下一个检查点的开始层数

这里调用custom函数——用于获取下一层的输入

        def custom(start, end):
            def custom_forward(*inputs):
                layers_ = self.layers[start:end]#获取对应的层序列
                x_, inputs = inputs[0], inputs[1:]#将他们分成两份(头和其余)
                    
                if is_sparse > 0:#稀疏处理
                    inputs, mems_ = inputs[:3], inputs[3:]#输入为前3层,其余为记忆模块
                else:#不采取稀疏处理
                    inputs, mems_ = inputs[:1], inputs[1:]#输入为第1层,其余为记忆模块

                for i, layer in enumerate(layers_):
                    mem_i_ = mems_[i] if mems_ else None#获取第i层的记忆模块
                    x_ = layer(x_, *inputs, mem=mem_i_)#调用get_layer中GPT2ParallelTransformerLayer的forward——x_对应hidden_states(输入), inputs对应ltor_mask(attention mask)
                    if self.max_memory_length > 0:
                        mem_layers.append(x_.detach())#记忆模块添加(不参与梯度计算)
                return x_
            return custom_forward

✨不采取检查点激活

思路和上面的检查点激活类似,只是不考虑了检查点层数和checkpoint

        else:#不进行检查点激活
            assert is_sparse != 1, 'Please use checkpoint_activations for sparse attention training.'
            for i, layer in enumerate(self.layers):#遍历Transformer layers
                if is_sparse == 0:#非稀疏处理——获取下一步传入的参数列表
                    args = [hidden_states, attention_mask_saved]
                elif is_sparse == 2:#稀疏推断
                    pivot_idx = torch.stack([
                        torch.cat((
                            text_idx,
                            img_indices[i][
                                torch.tensor(random.sample(range(len(img_indices[i])), k=num_pivot - len(text_idx)), dtype=torch.long, device=text_idx.device)
                            ]
                        ), dim=0)
                        for i, text_idx in enumerate(txt_indices)
                    ])#首先由random.sample随机抽取(预设支点数量-该batch的有效文本token长度)=该batch的有效图像token长度个图像token索引,并且将文本token和图像token拼接在一起
                    pw_idx = torch.cat((pivot_idx, window_idx), dim=-1)#获取随机选取的token的idx矩阵与额外标记注意的窗口的idx矩阵concat后的需要attention的idx矩阵
                    args = [hidden_states, attention_mask_saved, pw_idx, torch.tensor(is_sparse)]#参数列表记录

                mem_i = mems[i] if mems else None#对应层的记忆模块
                hidden_states = layer(*args, mem=mem_i)#下一层的输入层获取
                if self.max_memory_length > 0:#记忆层添加
                    mem_layers.append(hidden_states.detach())

(7)最后一层norm

作Layernorm操作

        # Final layer norm.
        output = self.final_layernorm(hidden_states)#即对下一层的输入(这一层的输出)做一个LayerNorm规范

(8)记忆模块更新

        #更新记忆模块
        if self.max_memory_length > 0:
            mem_layers = self.update_mems(mem_layers, mems)

这里调用update_mems进行更新

    def update_mems(self, hiddens, mems):
        memory_length = mems[0].size(1) if mems else 0#原记忆模块的长度(列数)
        query_length = hiddens[0].size(1)#新待加入的记忆模块长度(列数)
        new_memory_length = min(self.max_memory_length, memory_length + query_length)#新的记忆模块的长度确定
        new_mems = []
        with torch.no_grad():
            for i in range(len(hiddens)):
                if new_memory_length <= query_length:#说明选中的是self.max_memory_length(记忆模块完全为新的记忆层组成)。取每一层的每一行的后new_memory_length组成新的记忆矩阵
                    new_mems.append(hiddens[i][:, -new_memory_length:])
                else:#说明选中的是memory_length + query_length。取原来的记忆模块和新加入的进行拼接(沿列拼接)
                    new_mems.append(torch.cat((mems[i][:, -new_memory_length+query_length:], hiddens[i]), dim=1))
        return new_mems

(9)返回这一层的输出结果和记忆模块

        return (output, *mem_layers)#返回下一层的输入(这一层的输出结果)和记忆模块

欢迎大家在评论区批评指正,谢谢~

  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

tt丫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值