Notes on the BLIP source code (2): how the model is defined

This post looks at how the model is defined: what the class hierarchy looks like, which APIs are used, and how a single formula in the paper ends up calling quite a few functions in the code.

The training script calls the blip_retrieval function to obtain the model used in the paper; the rest of this note peels it apart layer by layer.

model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], 
                             vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], 
                             queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])

The blip_retrieval function:

def blip_retrieval(pretrained='',**kwargs):
    # instantiating BLIP_Retrieval with the given arguments produces the model
    model = BLIP_Retrieval(**kwargs)
    if pretrained:
        # if a pretrained checkpoint path/URL is given, load it with load_checkpoint
        model,msg = load_checkpoint(model,pretrained)
        print("missing keys:")
        # missing_keys: parameter keys that exist in the model but are absent from the checkpoint
        print(msg.missing_keys)
    return model

When a pretrained checkpoint is loaded, its parameters are stored as key-value pairs: each key names a parameter and the value holds its tensor. Keys can be missing for several reasons: the model architecture changed, some parameters were never saved, or the checkpoint file is corrupted. Printing the missing keys shows which parameters could not be filled from the checkpoint, which is useful for debugging or for deciding whether the model can be used as-is.
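
For reference, a minimal, self-contained sketch (toy modules, not BLIP code) of where missing_keys comes from: load_state_dict with strict=False returns a named tuple of missing and unexpected keys.

import torch
import torch.nn as nn

# a toy model whose 'head' has no counterpart in the toy checkpoint below
model = nn.ModuleDict({'backbone': nn.Linear(4, 4), 'head': nn.Linear(4, 2)})
checkpoint = {'backbone.weight': torch.zeros(4, 4), 'backbone.bias': torch.zeros(4)}

msg = model.load_state_dict(checkpoint, strict=False)
print(msg.missing_keys)     # ['head.weight', 'head.bias'] -- in the model, not in the checkpoint
print(msg.unexpected_keys)  # [] -- in the checkpoint, not in the model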

Next, the implementation of the BLIP_Retrieval class.

First, the entries in retrieval_coco.yaml that are relevant to BLIP_Retrieval:

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
vit: 'base'
vit_grad_ckpt: True
vit_ckpt_layer: 4
image_size: 384
queue_size: 57600
negative_all_rank: True
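
For reference, a sketch of how these values typically reach blip_retrieval; the config path and the argument handling of the training script are assumptions, simplified for illustration.

import yaml

config = yaml.load(open('configs/retrieval_coco.yaml', 'r'), Loader=yaml.Loader)
print(config['vit'], config['image_size'], config['queue_size'])  # base 384 57600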

The relevant passage from the paper (figure omitted):

According to the paper, self.visual_encoder is the image encoder: it divides an input image into patches, encodes them into a sequence of embeddings, and adds a [CLS] token that represents the global image feature.

class BLIP_Retrieval(nn.Module):
    def __init__(self,                 
                 med_config = 'configs/med_config.json',  
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,                      
                 embed_dim = 256,     
                 queue_size = 57600,
                 momentum = 0.995,
                 negative_all_rank = False,
                 ):
        # call the parent nn.Module constructor so that its initialization
        # runs before any submodules are registered
        super().__init__()
        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer() 

1. visual encoder

The create_vit function is defined as follows:

def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
        
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    
    if vit=='base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, 
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                          )   
    # the config sets vit to 'base', so the 'large' branch is omitted here
    # return the visual encoder instance together with vision_width
    return visual_encoder, vision_width
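
A usage sketch matching the retrieval_coco.yaml values above (vit='base', image_size=384, vit_grad_ckpt=True, vit_ckpt_layer=4); the comments describe what I expect to be printed, not captured output.

visual_encoder, vision_width = create_vit('base', 384, use_grad_checkpointing=True, ckpt_layer=4)
print(vision_width)          # 768
print(type(visual_encoder))  # the VisionTransformer class shown below (ViT-B/16, depth 12, 12 heads)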

Next, the implementation of the VisionTransformer class:

class VisionTransformer(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 
                 use_grad_checkpointing=False, ckpt_layer=0):
        super().__init__()
        # the number of output features equals the embedding dimension
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # partial function: nn.LayerNorm with eps fixed to 1e-6
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        
        # PatchEmbed splits the input image into equal-sized patches and embeds them
        # each patch becomes an embedding vector that represents that patch
        # with img_size=384 and patch_size=16 there are (384/16)**2 = 576 patches;
        # each 16x16x3 = 768-value patch is projected to embed_dim = 768
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

        # (384/16)**2 = 576
        num_patches = self.patch_embed.num_patches

        # the "CLS" token ("classification") lets the model aggregate global image information
        # self.cls_token is a learnable parameter, updated by backpropagation during training,
        # so that it can capture the global features needed by downstream tasks
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # positional encoding: gives each patch (and the "CLS" token) position information
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        # dropout applied after the positional encoding to reduce overfitting;
        # drop_rate is the probability of zeroing an element
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        # drop_path_rate = 0, depth = 12
        # torch.linspace(0, drop_path_rate, depth) creates depth evenly spaced values from 0 to
        # drop_path_rate; each value is converted to a Python number and stored in the list dpr
        # DropPath is applied to the attention and MLP branches of each encoder block
        # (see the Block class below); randomly dropping residual branches improves generalization
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # each Block is one Transformer encoder layer
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                # checkpoint only the last ckpt_layer blocks, i.e. those with i >= depth - ckpt_layer (12 - 4)
                use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        # a truncated normal distribution clips sampled values to a limited range,
        # avoiding extremely large or small initial values; std=.02 is the standard deviation
        trunc_normal_(self.pos_embed, std=.02) # initialize self.pos_embed with a truncated normal
        trunc_normal_(self.cls_token, std=.02) # initialize self.cls_token with a truncated normal
        # initialize the remaining modules
        self.apply(self._init_weights)

    # weight initialization: walks over all modules and initializes Linear and LayerNorm layers
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                # constant initialization
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0) # set the bias to 0
            nn.init.constant_(m.weight, 1.0)
            
    def no_weight_decay(self):
        # parameters that should be excluded from weight decay
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        B = x.shape[0] # x: (B, 3, H, W) input image batch
        x = self.patch_embed(x)

        # the cls token will occupy position 0 of the token sequence
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # expand replicates it B times along the batch dimension,
        # so its shape goes from (1, 1, 768) to (B, 1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        # concatenate along dim=1 (the token dimension): x becomes (B, N+1, D)
        x = torch.cat((cls_tokens, x), dim=1)
        # add position information for every token (patches and cls)
        x = x + self.pos_embed[:,:x.size(1),:]
        x = self.pos_drop(x) # this is the input to the Transformer encoder

        # pass x through the Transformer encoder blocks
        for i,blk in enumerate(self.blocks):
            x = blk(x, register_blk==i)

        # the encoder output goes through a final LayerNorm
        x = self.norm(x)
        # x.shape = (B, N+1, D); N = number of patches, D = embedding dimension
        return x
    # load pretrained parameters
    def load_pretrained(self, checkpoint_path, prefix=''):
        """ Load weights from .npz checkpoints for official Google Brain Flax implementation
        """
        _load_weights(self, checkpoint_path, prefix)
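
As a quick sanity check on the shapes (a sketch that assumes the VisionTransformer above can be instantiated together with its timm dependencies such as PatchEmbed, Block and trunc_normal_):

import torch

vit = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12)
out = vit(torch.randn(2, 3, 384, 384))
print(out.shape)   # torch.Size([2, 577, 768]) -- 576 patches (24 x 24) plus the cls token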
        

VisionTransformer builds on the Block class, which corresponds to the boxed Transformer encoder block in the paper's figure (figure omitted): LayerNorm first, then attention, then a residual add; after that another LayerNorm, the MLP, and another residual add. In other words, each sub-block starts with a norm and ends with a residual connection.

The code:

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
        super().__init__()
        
        self.norm1 = norm_layer(dim)
        # in the Attention module the multi-head attention is followed by a projection, hence both attn_drop and proj_drop
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        # if drop_path <= 0. this is a no-op (nn.Identity)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # second LayerNorm
        self.norm2 = norm_layer(dim)
        # mlp_ratio = 4: ratio of the hidden dimension to the input dimension
        # mlp_hidden_dim = 768 x 4
        mlp_hidden_dim = int(dim * mlp_ratio)
        # MLP block
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
		
        # if use_grad_checkpointing is True, wrap the attention and MLP modules with gradient checkpointing
        if use_grad_checkpointing: # whether to use gradient checkpointing
            # wrap attention and mlp with checkpoint_wrapper
            self.attn = checkpoint_wrapper(self.attn)
            self.mlp = checkpoint_wrapper(self.mlp)

    def forward(self, x, register_hook=False):
        # after attention and after the MLP, apply drop path and then the residual add
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

Gradient checkpointing reduces memory use at the cost of extra computation, which matters when the computation graph is large. Instead of storing every intermediate activation during the forward pass, the wrapped part of the graph is recomputed during the backward pass when its gradients are needed. Wrapping the self-attention and MLP modules this way trades some compute for a noticeably smaller memory footprint.
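
A minimal sketch of the same idea using PyTorch's built-in torch.utils.checkpoint (the BLIP code wraps modules with checkpoint_wrapper from fairscale instead): activations inside block are not cached during the forward pass and are recomputed during backward.

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.GELU())
x = torch.randn(2, 577, 768, requires_grad=True)
y = checkpoint(block, x)   # forward pass without storing intermediate activations
y.sum().backward()         # the block's forward is re-run here to compute gradients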

Block also calls the Attention class; note that the multi-head attention is followed by a linear projection layer:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads # 768//12 = 64
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5 # 64 ** -0.5 = 1/8
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None
    # ... (helper methods such as save_attention_map / save_attn_gradients omitted) ...
    
    def forward(self, x, register_hook=False):
        B, N, C = x.shape # batch_size N(patches) Dimension(768)
        # self.qkv(x) gives (B, N, 768*3); reshape to (B, N, 3, 12, 64), then permute to (3, B, 12, N, 64)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q, k, v each have shape (B, 12, N, 64)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
        # q: (B, 12, N, 64), k.transpose(-2, -1): (B, 12, 64, N), so attn has shape (B, 12, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
                
        if register_hook:
            # save the attention map and register a hook to save its gradients
            self.save_attention_map(attn)
            # register_hook registers self.save_attn_gradients as a gradient hook on the
            # attention tensor attn; when the gradient of attn is computed during backward,
            # the hook is called and can store it (or do any other custom processing)
            attn.register_hook(self.save_attn_gradients)        
        # (B, 12, N, N) @ (B, 12, N, 64) = (B, 12, N, 64); after transpose: (B, N, 12, 64); after reshape: (B, N, 768)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # followed by the output projection
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

register_hook is a PyTorch method that registers a hook function on a tensor. The hook runs when that tensor's gradient is computed, so it can record, modify, or analyze the gradient. Registering hooks lets you monitor values and gradients during the forward and backward passes, which is useful for visualization, debugging, or gradient manipulation.
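
A standalone toy example of register_hook (not BLIP code): the hook receives the gradient of attn during backward and stashes it in a dict.

import torch

saved = {}
attn = torch.randn(1, 12, 5, 5, requires_grad=True)
attn.register_hook(lambda grad: saved.update(attn_grad=grad))  # called during backward
attn.softmax(dim=-1).sum().backward()
print(saved['attn_grad'].shape)   # torch.Size([1, 12, 5, 5])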

Block also uses the Mlp class:

class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
        From the ViT paper: "The MLP contains two layers with a GELU non-linearity."
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # in_features = 768, hidden_features = 768 x 4
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # first linear layer
        self.fc1 = nn.Linear(in_features, hidden_features)
        # activation function, nn.GELU here
        self.act = act_layer()
        # second linear layer
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

2. multimodal mixture of encoder-decoder (MED)

class BLIP_Retrieval(nn.Module):
    # ... (beginning of __init__ omitted) ...
        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
        # the visual_encoder was covered in section 1; continuing from here
        # the tokenizer splits input text into smaller units (words, subwords, characters)
        # for subsequent processing and encoding
        self.tokenizer = init_tokenizer() 
        # load the MED configuration file
        med_config = BertConfig.from_json_file(med_config)
        # override the config: set the MED encoder_width to vision_width (768 for the base ViT)
        med_config.encoder_width = vision_width
        # the text encoder turns tokenized text into vector representations; it could be a
        # pretrained model such as BERT or GPT, or a simpler embedding model such as Word2Vec
        # or GloVe -- here the paper uses a customized BertModel class:
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)   

This calls the custom init_tokenizer function, which is essentially BertTokenizer plus a few extra special tokens:

# BERT tokenizer
def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    # [DEC] marks the beginning of a sequence for the decoder
    tokenizer.add_special_tokens({'bos_token':'[DEC]'})
    # [ENC] marks the start token for the (image-grounded) text encoder
    tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]  
    return tokenizer
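
A minimal usage sketch, assuming transformers is installed and bert-base-uncased can be downloaded; the exact ids depend on the tokenizer version, so the comments only describe what to expect.

tokenizer = init_tokenizer()
text = tokenizer('a dog chasing a ball', padding='longest', return_tensors='pt')
print(text.input_ids.shape)     # (1, sequence_length), starting with [CLS] and ending with [SEP]
print(tokenizer.bos_token_id)   # id assigned to the added [DEC] token
print(tokenizer.enc_token_id)   # id assigned to the added [ENC] token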

It also uses a customized BertModel class; here is the relevant passage from the paper, followed by the code:

Image-grounded text encoder, which injects visual information by inserting one additional cross-attention (CA) layer between the self-attention (SA) layer and the feed forward network (FFN) for each transformer block of the text encoder. (This mirrors the decoder of the original Transformer.) A task-specific [Encode] token is appended to the text, and the output embedding of [Encode] is used as the multimodal representation of the image-text pair. In other words, the embedding of the [Encode] token fuses image and text information and serves as the representation of the pair.

Image-grounded text decoder, which replaces the bidirectional self-attention layers in the image-grounded text encoder with causal self-attention layers. (The decoder can therefore only attend to positions before the current one and never sees future tokens.) A [Decode] token is used to signal the beginning of a sequence, and an end-of-sequence token is used to signal its end.

class BertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in
    `Attention is all you need`. To be used as a decoder the model needs to be initialized with the :obj:`is_decoder`
    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
    input to the forward pass.
    """

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

BertModel inherits from BertPreTrainedModel:

class BertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

BertModel is assembled from several other classes. First, BertEmbeddings:

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        # config.pad_token_id = 0
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        # config.layer_norm_eps: epsilon added for numerical stability in LayerNorm
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        # register a buffer tensor named "position_ids"
        # a registered buffer is saved when the model is serialized and restored on load
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        # absolute position encoding (the default)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

        self.config = config

    def forward(
        self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            # [:-1]: keep every dimension except the last (the hidden dimension)
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            # position_ids has shape (1, max_position_embeddings)
            # take positions from past_key_values_length up to seq_length + past_key_values_length - 1
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        # word embedding + position embedding, followed by LayerNorm and dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
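
A stripped-down, self-contained equivalent of this forward pass; the sizes are assumptions for illustration (hidden_size 768, a vocabulary of 30524 after the two added special tokens, maximum length 512):

import torch
import torch.nn as nn

vocab, hidden, max_len = 30524, 768, 512
word_emb = nn.Embedding(vocab, hidden, padding_idx=0)
pos_emb = nn.Embedding(max_len, hidden)
ln, drop = nn.LayerNorm(hidden, eps=1e-12), nn.Dropout(0.1)

input_ids = torch.randint(0, vocab, (2, 35))                 # (B, L) token ids
position_ids = torch.arange(35).unsqueeze(0)                 # (1, L)
out = drop(ln(word_emb(input_ids) + pos_emb(position_ids)))  # (2, 35, 768)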

Next, BertEncoder, which stacks 12 BertLayer modules:

class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # a stack of config.num_hidden_layers (12) BertLayer modules
        self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        mode='multimodal',
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
               
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    mode=mode,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    mode=mode,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

BertEncoder in turn calls BertLayer:

class BertLayer(nn.Module):
    # 1. self-attention, followed by a linear layer + LayerNorm (BertSelfOutput)
    # 2. optional cross-attention (when add_cross_attention is set), again followed by a linear layer + LayerNorm
    # 3. BertIntermediate: a linear layer mapping hidden_size to intermediate_size, plus an activation
    # 4. BertOutput: a linear layer mapping intermediate_size back to hidden_size, then LayerNorm with a residual
    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)      
        self.layer_num = layer_num          
        if self.config.add_cross_attention:
            self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        mode=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        # cross-attention is applied only in 'multimodal' mode
        if mode=='multimodal':
            assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"

            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights                               
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
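
For reference, apply_chunking_to_forward is a transformers utility that splits the input along seq_len_dim into chunks of chunk_size_feed_forward, applies feed_forward_chunk to each chunk, and concatenates the results; with a chunk size of 0 it is just a direct call. A rough sketch of the idea, not the library implementation:

import torch

def chunked_feed_forward(forward_fn, chunk_size, chunk_dim, x):
    # chunk_size == 0 means "no chunking": apply the feed-forward to the whole tensor
    if chunk_size == 0:
        return forward_fn(x)
    # otherwise process the sequence dimension in slices to reduce peak memory
    chunks = x.split(chunk_size, dim=chunk_dim)
    return torch.cat([forward_fn(c) for c in chunks], dim=chunk_dim)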

BertLayer calls BertAttention, BertIntermediate, and BertOutput. First, the BertAttention code:

class BertAttention(nn.Module):
    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    # prune attention heads
    # pruning is a model-compression technique that shrinks the model and its compute cost
    # heads is a list of indices of the attention heads to prune
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # self.self is the (self- or cross-) attention layer
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # then the output layer (projection + dropout + LayerNorm with residual)
        attention_output = self.output(self_outputs[0], hidden_states)
        # self_outputs[1:]: attention_probs (optional) and past_key_value
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs # returns attention_output, (attention_probs,) past_key_value

BertAttention in turn calls BertSelfAttention and BertSelfOutput.
Here is BertSelfAttention. The is_cross_attention flag decides whether this behaves like the encoder or the decoder of the original Transformer: when it is True, the keys and values are projected from encoder_hidden_states (cross-attention, as in the Transformer decoder); when it is False, they come from hidden_states itself (plain self-attention, as in the Transformer encoder).

class BertSelfAttention(nn.Module):
    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        # config.num_attention_heads = 12
        self.num_attention_heads = config.num_attention_heads
        # 768/12 = 64
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        # self.all_head_size = 12 * 64 = 768
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # linear layer: 768 -> 768
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # if this module is used as cross-attention
        if is_cross_attention:
            # config.encoder_width = 768 (the ViT width set earlier)
            self.key = nn.Linear(config.encoder_width, self.all_head_size)
            self.value = nn.Linear(config.encoder_width, self.all_head_size)
        else:
            # config.hidden_size = 768
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        self.save_attention = False   
            
    # ... (save_attention_map / save_attn_gradients helpers omitted) ...
    # reshape (B, L, all_head_size) -> (B, num_heads, L, head_size)
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        
        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)         

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
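
As a shape sketch of the cross-attention path inside this module (the numbers are assumptions: batch 2, 12 heads of size 64, 35 text tokens, 577 image tokens from the ViT):

import torch

q = torch.randn(2, 12, 35, 64)    # queries projected from the text hidden_states
k = torch.randn(2, 12, 577, 64)   # keys projected from encoder_hidden_states (image features)
v = torch.randn(2, 12, 577, 64)   # values projected from encoder_hidden_states
attn = torch.softmax(q @ k.transpose(-1, -2) / 64 ** 0.5, dim=-1)   # (2, 12, 35, 577)
ctx = (attn @ v).permute(0, 2, 1, 3).reshape(2, 35, 12 * 64)        # (2, 35, 768)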

Next, the BertSelfOutput code:

class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # linear layer + dropout + LayerNorm with a residual connection
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # residual connection
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

The BertIntermediate code:

class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # linear layer mapping hidden_size to intermediate_size
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # followed by the activation function
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        # hidden states of the intermediate layer
        return hidden_states

And the BertOutput code:

class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # linear layer + dropout + LayerNorm with a residual connection
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

Summary:

BertEncoder stacks BertLayer modules; BertLayer uses BertAttention, BertIntermediate, and BertOutput; and BertAttention in turn uses BertSelfAttention and BertSelfOutput.

Back in the BertModel class:

The BertPooler code (note that BLIP_Retrieval creates BertModel with add_pooling_layer=False, so this pooler is not instantiated there):

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
	
    # pooling only looks at the first token ([CLS])
    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output