前情回顾
在前面的章节里,我们实现了视觉transformer编码器模型,也对输入的prompt和图像数据进行了预处理,现在我们针对每一个数据对,对于prompt我们能获得其拆解之后的token id向量,而对于图像,我们能获得num_patch个图像块的编码嵌入向量。
如下图所示:
然而,我们需要为后面的语言模型构建输入,故首先需要对每个token进行向量嵌入编码,除此之外,我们需要通过一个projector来使得图像块的编码维度与token一致,然后把它们组装起来,还要为后面的过程提供位置id向量以及attention掩码向量。
下图显示了主模型的架构以及需要做的事情,所以我们将根据主模型的架构来设计模型。
主模型的上层搭建
首先贴出部分代码:
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self,config:PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
self.language_model = GemmaForCausalLM(config.text_config)
self.pad_token_id = config.pad_token_id
def tie_weights(self):
self.language_model.tie_weights()
我们的主模型首先需要一个全局config,因为是全局的config,所以内部会有视觉模型部分的config以及语言模型部分的config还有一些其他配置,首先我们先定义好语言模型配置类:
class GemmaConfig():
def __init__(
self,
vocab_size,
hidden_size,
intermediate_size,
num_hidden_layers,
num_attention_heads,
num_key_value_heads,
head_dim = 256,
max_position_embeddings = 8192,
rms_norm_eps = 1e-6,
rope_theta = 10000.0,
attention_bias = False,
attention_dropout = 0.0,
pad_token_id = None,
**kwargs
):
super().__init__()
self.vocab_size = vocab_size ##词汇表大小
self.hidden_size = hidden_size ##嵌入维度
self.intermediate_size = intermediate_size ##MLP中间层维度
self.num_hidden_layers = num_hidden_layers ##注意力层数
self.num_attention_heads = num_attention_heads ##注意力头数
self.num_key_value_heads = num_key_value_heads ##组注意力相关
self.head_dim = head_dim ##每个head的维度
self.max_position_embeddings = max_position_embeddings ##旋转位置编码相关
self.rms_norm_eps = rms_norm_eps ##均方根归一化相关
self.rope_theta = rope_theta ##旋转位置编码相关
self.attention_bias = attention_bias #attention bias
self.attention_dropout = attention_dropout ##注意力丢弃机制
self.pad_token_id = pad_token_id ##填充符id
其次是主模型的配置,主模型的配置里面集成了语言模型配置,视觉编码器配置,和主模型运行相关的配置。
class PaliGemmaConfig():
def __init__(
self,
vision_config = None,
text_config = None,
vocab_size:int = 257152,
ignore_idx = -100,
projection_dim = 2048,
image_token_index = 256000,
hidden_size = 2048,
pad_token_id:int = None,
**kwargs
):
super().__init__()
self.vision_config = SiglipVisionConfig(**vision_config)
self.text_config = GemmaConfig(**text_config,pad_token_id= pad_token_id)
self.vocab_size = vocab_size
self.ignore_idx = ignore_idx
self.projection_dim = projection_dim
self.image_token_index = image_token_index
self.hidden_size = hidden_size
self.pad_token_id = pad_token_id
self.is_encoder_decoder = False
self.text_config.vocab_size = vocab_size
self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2 ##每个图像生成的token数量
self.vision_config.projection_dim = self.projection_dim ## 图像嵌入投影到token嵌入的投影器维度
图像的token投影矩阵
由于图像对每个patch的编码维度不一定与token的嵌入维度相同,我们需要引入一个PaliGemmaMultiModalProjector。
这个投影器本质上就是一个线性MLP层,其实现代码如下:
class PaliGemmaMultiModalProjector(nn.Module): ##匹配
def __init__(self,config:PaliGemmaConfig):
super().__init__()
self.config = config
self.linear = nn.Linear(config.vision_config.hidden_size,config.projection_dim)
##将图像嵌入的每个patch的embedding投影到文本嵌入相同的维度
## [Batch_size, Num_image_tokens, Image_Embedding_dim] -> [Batch_size, Num_image_tokens, Hidden_size]
def forward(self,image_features:torch.Tensor):
hidden_states = self.linear(image_features)
return hidden_states
除了投影器,我们还需要为主模型搭建语言模型GemmaForCausalLM,但我们将在后面的章节实现这些。
参数捆绑
参数捆绑机制是指将模型的两个不同层的参数捆绑,使得它们的参数一致,这样可以减少模型需要训练的参数量,前提是这两个层的功能必须有关联,这样的参数才有实际意义。
def tie_weights(self):
self.language_model.tie_weights()
这里实现的参数捆绑实际上是将token->embedding 的模型层与模型输出的embedding -> 每个token概率的模型层进行捆绑,因为这两者是逆过程,所以功能有所关联,参数捆绑是有意义的。
前向过程
主模型的前向过程如下:
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
**kwargs
)-> Tuple:
assert torch.all(attention_mask == 1), 'attention_mask must be 1 beacuse we do not use padding'
## [Batch_size, Seq_len, Hidden_size]
input_embeds = self.language_model.get_input_embeddings()(input_ids)
## [Batch_size, Channels, Height, Width] --> [Batch_size, Patch_size, Embed_dim]
selected_image_feature = self.vision_tower(pixel_values.to(input_embeds.dtype))
##注意,这里的Embed_dim的维度可能与Hidden_size不同,需要做一下转换
image_features = self.multi_modal_projector(selected_image_feature)
input_embeds,attention_mask,position_ids = self._merge_inputs_id_with_image_features(input_embeds=input_embeds,input_ids=input_ids,image_features=image_features,kv_cache=kv_cache,attention_mask=attention_mask)
outputs = self.language_model(
inputs_embeds = input_embeds,
position_ids = position_ids,
attention_mask = attention_mask,
kv_cache = kv_cache,
**kwargs
)
return outputs
这里的输入是之前经过数据预处理模块得到的输出,然后我们先对每个token_id进行嵌入,同时对每个图像进行嵌入,并将图像的patch嵌入维度与token的嵌入维度对齐,最后将它们组装成语言模型的输入,并传入language_model进行推理,返回输出的token概率分布。这里的关键在于输入构建函数的实现。
输入构建
def _merge_inputs_id_with_image_features(
self, image_features: torch.Tensor,
input_embeds: torch.Tensor,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
):
_,_,embed_dim = image_features.shape
batch_size,seq_len = input_ids.shape ##不同的输入的seq_len是不同的,这就是为什么之前要限制一次只能载入一张图片和文本,否则要做填充的工作
dtype,device = input_embeds.dtype,input_embeds.device
#shape [Batch_size, Seq_len, Hidden_size]
scaled_image_features = image_features * (embed_dim ** -0.5)
##构建最终返回的embeddings
final_embeddings = torch.zeros(batch_size,seq_len,embed_dim,dtype=dtype,device=device)
## [Batch_size, Seq_len] 形如下:
## [
## Batch_1 [True,True,True,....False,False]
## Batch_2 [True,True,True,....False,False]
## ...
## Batch_n [True,True,True,....False,False]
## ]
text_mask = ( input_ids != self.config.image_token_index ) & ( input_ids != self.pad_token_id ) ## 只有text部分为True,其他为False
image_mask = ( input_ids == self.config.image_token_index )
padding_mask = ( input_ids == self.pad_token_id )
## 目前只有两个维度,没有embedding维度,先扩展一个维度,方便后续的mask操作
## [Batch_size, Seq_len, 1] 增广一维
text_mask = text_mask.unsqueeze(-1)
image_mask = image_mask.unsqueeze(-1)
padding_mask = padding_mask.unsqueeze(-1)
## expand函数会将指定的size为1的维度扩展到指定数量的维度,扩展方式是复制
## 比如 [[1],[1],[1]]的shape为(3,1),现在用expand将其扩展为(-1,3)那么此时的输出是[[1,1,1],[1,1,1],[1,1,1]]
## 此时需要把mask的embedding维度扩展到与input_embeds相同
text_mask = text_mask.expand(-1,-1,embed_dim)
image_mask = image_mask.expand(-1,-1,embed_dim)
padding_mask = padding_mask.expand(-1,-1,embed_dim)
## where函数根据条件进行替换,如果mask为True,那么替换为input_embeds,否则替换为final_embeddings
final_embeddings = torch.where(text_mask,input_embeds,final_embeddings)
## 这里也是依据image_mask选择是True的位置进行替换,之所以不能用where是因为image_feature的形状与mask的形状不匹配,image_mask函数只会关注那些mask为True的位置
final_embeddings = final_embeddings.masked_scatter(image_mask,scaled_image_features)
## padding的替换是把所有padding位置的嵌入向量变成0向量
final_embeddings = torch.where(padding_mask,torch.zeros_like(final_embeddings),final_embeddings)
dtype,device = dtype,device
q = input_embeds.shape[1]
min_dtype = torch.finfo(dtype).min
if kv_cache is None or kv_cache.num_items() == 0: ##表明此时kv_cache是空的,说明是在预载入prompt阶段,此时根据论文可以不用屏蔽任何token,因为屏蔽了也没用,模型只会取最后一个token的embedding来预测下一个token,而最后一个embedding要求看到之前左右的token
causal_mask = torch.full(
(batch_size,q,q),
fill_value= 0,
dtype=dtype,
device=device,
)
else: ##表明此时已经在推理过程中了,推理一次只会生成一个token,q == 1,而且推理过程中依据最新的这个token进行预测,所以也不需要屏蔽掉之前的token,需要屏蔽的置为负无穷,不需要屏蔽的置为0
assert q == 1
kv_len = kv_cache.num_items() + q
causal_mask = torch.full(
(batch_size,q,kv_len),
fill_value= 0,
dtype=dtype,
device=device,
)
## causal_mask的形状为[Batch_size, Q_len, KV_len]
## 需要增广一个多头注意力维度 [Batch_size, 1, Q_len, KV_len]
causal_mask = causal_mask.unsqueeze(1)
##attention_mask的形状为[Batch_size, Seq_len] 其中,1表示有效token,0表示padding token
if kv_cache is not None and kv_cache.num_items() > 0:
position_ids = attention_mask.cumsum(dim=-1) ##沿着最后一个维度计算sum 此时position_ids的形状为[Batch_size, Seq_len]
##因为在推理过程中一次只会有一个q,所以,q的position_id是position_ids每一个batch中最后的那个sum的值
position_ids = position_ids[:,-1]
if position_ids.dim() == 1:
position_ids = position_ids.unsqueeze(0) ##加一个batch维度
else:
##全量forward时,每一个token都要加上position_id
position_ids = (attention_mask.cumsum(dim=-1).masked_fill((attention_mask == 0),1).to(device)) ##masked_fill将条件为真的位置的值替换为1,attention_mask: [Batch_size, Seq_len]
return final_embeddings,causal_mask,position_ids
这里首先来解读一下函数的输入:
- input_embeds: 代表每个token的嵌入向量,形状为[B,Seq,Embedding_size]
- input_ids: 代表每个token的id,形状为[B,Seq]
- attention_mask: 标识每个token是有效的还是填充的,形状为[B,Seq]
- kv_cache: 用来指示当前推理过程中已经有了几个token了
在函数的初期,我们首先将图像特征统一除以了sqrt(embedding size)
scaled_image_features = image_features * (embed_dim ** -0.5)
这是因为后面会需要将所有的特征(图像+文本)乘以sqrt(embedding size)来让特征的方差趋近于1,但是图像特征的方差一般已经在1周围了,如果再去做乘积,就会使得图像的方差变的很大,两种模态的特征方差就不平衡了,所以这里提前做一些补偿。
随后我们准备了一个空的输入特征,来喂给Gemma 模型,此时全部用0填充。
final_embeddings = torch.zeros(batch_size,seq_len,embed_dim,dtype=dtype,device=device)
然后我们根据输入的token id,构建了用于区分image token,text token和padding token的三张掩码。并且我们让掩码的大小与输入特征的大小一致,这样可以方便final embedding的替换。
随后我们利用pytorch的scatter_mask函数和where函数来填充对应的特征
## where函数根据条件进行替换,如果mask为True,那么替换为input_embeds,否则替换为final_embeddings
final_embeddings = torch.where(text_mask,input_embeds,final_embeddings)
## 这里也是依据image_mask选择是True的位置进行替换,之所以不能用where是因为image_feature的形状与mask的形状不匹配,image_mask函数只会关注那些mask为True的位置
final_embeddings = final_embeddings.masked_scatter(image_mask,scaled_image_features)
## padding的替换是把所有padding位置的嵌入向量变成0向量
final_embeddings = torch.where(padding_mask,torch.zeros_like(final_embeddings),final_embeddings)
至此,输入准备好了,但我们还需要准备模型后面会要用的注意力掩码和位置id。
注意力掩码可以根据kv cache中的信息来知道当前的token数量,从而构建注意力掩码矩阵,需要屏蔽的位置设置为负无穷,不需要屏蔽的位置为0。但此时,由于依次只有一个token推理,所以推理过程中注意力掩码无需屏蔽任何token,而在预加载用户的prompt时,论文中也没有屏蔽任何token,这是论文的做法,不同的人有不同的实现而已。
如图,这里在[sep]之前都没有屏蔽任何已有的token,sep之后由于单token推理,所以也无需屏蔽token。于是我们构建了全0的注意力掩码矩阵。
if kv_cache is None or kv_cache.num_items() == 0: ##表明此时kv_cache是空的,说明是在预载入prompt阶段,此时根据论文可以不用屏蔽任何token,因为屏蔽了也没用,模型只会取最后一个token的embedding来预测下一个token,而最后一个embedding要求看到之前左右的token
causal_mask = torch.full(
(batch_size,q,q),
fill_value= 0,
dtype=dtype,
device=device,
)
else: ##表明此时已经在推理过程中了,推理一次只会生成一个token,q == 1,而且推理过程中依据最新的这个token进行预测,所以也不需要屏蔽掉之前的token,需要屏蔽的置为负无穷,不需要屏蔽的置为0
assert q == 1
kv_len = kv_cache.num_items() + q
causal_mask = torch.full(
(batch_size,q,kv_len),
fill_value= 0,
dtype=dtype,
device=device,
)
最后我们根据attention mask来计算position id,因为这里我们没有做padding操作,因为我们一开始就限制了一次推理只有一个图像文本对,所以attention mask是全1的,所以可以通过对用前缀和的方式计算attention mask,得到每一个token的位置id。
由于推理阶段只有一个token会被传递到后续进行attention计算 (这是因为kv cache已经记录了之前所有token的k向量和v向量,并且之前的token之间的qk计算是没有意义的,因为attention mask的存在,之前的一些token间的qk早在生成当前token之前已经做过了,所以无需将之前的token传递到后续阶段来计算attention),所以我们取postion_ids的最后一个值作为当前token的postion_id传入,以进行旋转位置编码。
每次只有新产生的那个token的kv计算是有效计算,其余token的kv计算是重复计算
position_ids = position_ids[:,-1]
而在全量forward(即读取用户输入的prompt并计算kv cache)的过程中,由于此时kv cache是空的,所以用户的prompt的所有token是一次性输入到模型中的,用来填充kv cache,所以要把所有的postion_ids都传入,并且确保position_ids有一个batch维度。
但注意,在全量forward的过程中,我们没有形成类似于上三角矩阵的这样的注意力掩码,因为这是论文选择的做法,仅此而已。
随后我们返回 模型输入嵌入,注意力掩码矩阵和位置id,位置id用于计算旋转位置编码:
return final_embeddings,causal_mask,position_ids