手搓多模态-07 主模型上层的搭建

前情回顾

在前面的章节里,我们实现了视觉transformer编码器模型,也对输入的prompt和图像数据进行了预处理,现在我们针对每一个数据对,对于prompt我们能获得其拆解之后的token id向量,而对于图像,我们能获得num_patch个图像块的编码嵌入向量。

如下图所示:

然而,我们需要为后面的语言模型构建输入,故首先需要对每个token进行向量嵌入编码,除此之外,我们需要通过一个projector来使得图像块的编码维度与token一致,然后把它们组装起来,还要为后面的过程提供位置id向量以及attention掩码向量。

下图显示了主模型的架构以及需要做的事情,所以我们将根据主模型的架构来设计模型。

主模型的上层搭建

首先贴出部分代码:

class PaliGemmaForConditionalGeneration(nn.Module):
	
	def __init__(self,config:PaliGemmaConfig):
		super().__init__()
		self.config = config
		self.vision_tower = SiglipVisionModel(config.vision_config)
		self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
		self.vocab_size = config.vocab_size

		self.language_model = GemmaForCausalLM(config.text_config)

		self.pad_token_id = config.pad_token_id

	def tie_weights(self):
		self.language_model.tie_weights()

我们的主模型首先需要一个全局config,因为是全局的config,所以内部会有视觉模型部分的config以及语言模型部分的config还有一些其他配置,首先我们先定义好语言模型配置类:

class GemmaConfig(): 
	def __init__(
		self,
		vocab_size,
		hidden_size,
		intermediate_size,
		num_hidden_layers,
		num_attention_heads,
		num_key_value_heads,
		head_dim = 256,
		max_position_embeddings = 8192,
		rms_norm_eps = 1e-6,
		rope_theta = 10000.0,
		attention_bias = False,
		attention_dropout = 0.0,
		pad_token_id = None,
		**kwargs
	):
		super().__init__()
		self.vocab_size = vocab_size ##词汇表大小
		self.hidden_size = hidden_size ##嵌入维度
		self.intermediate_size = intermediate_size ##MLP中间层维度
		self.num_hidden_layers = num_hidden_layers ##注意力层数
		self.num_attention_heads = num_attention_heads ##注意力头数
		self.num_key_value_heads = num_key_value_heads ##组注意力相关
		self.head_dim = head_dim ##每个head的维度
		self.max_position_embeddings = max_position_embeddings ##旋转位置编码相关
		self.rms_norm_eps = rms_norm_eps ##均方根归一化相关
		self.rope_theta = rope_theta ##旋转位置编码相关
		self.attention_bias = attention_bias #attention bias
		self.attention_dropout = attention_dropout ##注意力丢弃机制
		self.pad_token_id = pad_token_id ##填充符id

其次是主模型的配置,主模型的配置里面集成了语言模型配置,视觉编码器配置,和主模型运行相关的配置。

class PaliGemmaConfig():
	def __init__(
		self,
		vision_config = None,
		text_config = None,
		vocab_size:int = 257152,
		ignore_idx = -100,
		projection_dim = 2048,
		image_token_index = 256000,
		hidden_size = 2048,
		pad_token_id:int = None,
		**kwargs
	):
		super().__init__()
		self.vision_config = SiglipVisionConfig(**vision_config)
		self.text_config = GemmaConfig(**text_config,pad_token_id= pad_token_id)

		self.vocab_size = vocab_size
		self.ignore_idx = ignore_idx
		self.projection_dim = projection_dim
		self.image_token_index = image_token_index
		self.hidden_size = hidden_size
		self.pad_token_id = pad_token_id
		self.is_encoder_decoder = False
		
		self.text_config.vocab_size = vocab_size
		self.text_config.num_image_tokens = (self.vision_config.image_size // self.vision_config.patch_size) ** 2 ##每个图像生成的token数量
		self.vision_config.projection_dim = self.projection_dim ## 图像嵌入投影到token嵌入的投影器维度

图像的token投影矩阵

由于图像对每个patch的编码维度不一定与token的嵌入维度相同,我们需要引入一个PaliGemmaMultiModalProjector。

这个投影器本质上就是一个线性MLP层,其实现代码如下:

class PaliGemmaMultiModalProjector(nn.Module): ##匹配
	def __init__(self,config:PaliGemmaConfig):
		super().__init__()
		self.config = config
		self.linear = nn.Linear(config.vision_config.hidden_size,config.projection_dim)

		##将图像嵌入的每个patch的embedding投影到文本嵌入相同的维度
		## [Batch_size, Num_image_tokens, Image_Embedding_dim] -> [Batch_size, Num_image_tokens, Hidden_size]
	def forward(self,image_features:torch.Tensor):
		hidden_states = self.linear(image_features)
		return hidden_states

除了投影器,我们还需要为主模型搭建语言模型GemmaForCausalLM,但我们将在后面的章节实现这些。

参数捆绑

参数捆绑机制是指将模型的两个不同层的参数捆绑,使得它们的参数一致,这样可以减少模型需要训练的参数量,前提是这两个层的功能必须有关联,这样的参数才有实际意义。

	def tie_weights(self):
		self.language_model.tie_weights()

这里实现的参数捆绑实际上是将token->embedding 的模型层与模型输出的embedding -> 每个token概率的模型层进行捆绑,因为这两者是逆过程,所以功能有所关联,参数捆绑是有意义的。

前向过程

主模型的前向过程如下:

	def forward(
		self,
		input_ids: torch.LongTensor = None,
		pixel_values: torch.FloatTensor = None,
		attention_mask: Optional[torch.Tensor] = None,
		kv_cache: Optional[KVCache] = None,
		**kwargs
	)-> Tuple:
		assert torch.all(attention_mask == 1), 'attention_mask must be 1 beacuse we do not use padding'

		## [Batch_size, Seq_len, Hidden_size]
		input_embeds = self.language_model.get_input_embeddings()(input_ids)

		## [Batch_size, Channels, Height, Width] --> [Batch_size, Patch_size, Embed_dim]
		selected_image_feature = self.vision_tower(pixel_values.to(input_embeds.dtype))

		##注意,这里的Embed_dim的维度可能与Hidden_size不同,需要做一下转换
		image_features = self.multi_modal_projector(selected_image_feature)

		input_embeds,attention_mask,position_ids = self._merge_inputs_id_with_image_features(input_embeds=input_embeds,input_ids=input_ids,image_features=image_features,kv_cache=kv_cache,attention_mask=attention_mask)

		outputs = self.language_model(
			inputs_embeds = input_embeds,
			position_ids = position_ids,
			attention_mask = attention_mask,
			kv_cache = kv_cache,
			**kwargs
		)

		return outputs

这里的输入是之前经过数据预处理模块得到的输出,然后我们先对每个token_id进行嵌入,同时对每个图像进行嵌入,并将图像的patch嵌入维度与token的嵌入维度对齐,最后将它们组装成语言模型的输入,并传入language_model进行推理,返回输出的token概率分布。这里的关键在于输入构建函数的实现。

输入构建

	def _merge_inputs_id_with_image_features(
		self, image_features: torch.Tensor,
		input_embeds: torch.Tensor,
		input_ids: torch.LongTensor,
		attention_mask: Optional[torch.Tensor] = None,
		kv_cache: Optional[KVCache] = None,
	):
		_,_,embed_dim = image_features.shape
		batch_size,seq_len = input_ids.shape ##不同的输入的seq_len是不同的,这就是为什么之前要限制一次只能载入一张图片和文本,否则要做填充的工作
		dtype,device = input_embeds.dtype,input_embeds.device

		#shape [Batch_size, Seq_len, Hidden_size]
		scaled_image_features = image_features * (embed_dim ** -0.5)

		##构建最终返回的embeddings
		final_embeddings = torch.zeros(batch_size,seq_len,embed_dim,dtype=dtype,device=device)


		## [Batch_size, Seq_len] 形如下:
		## 		   [
		## Batch_1 [True,True,True,....False,False]
		## Batch_2 [True,True,True,....False,False]
		##  					...
		## Batch_n [True,True,True,....False,False]
		## 		   									]
		text_mask = ( input_ids != self.config.image_token_index ) & ( input_ids != self.pad_token_id ) ## 只有text部分为True,其他为False
		image_mask = ( input_ids == self.config.image_token_index )
		padding_mask = ( input_ids == self.pad_token_id )

		## 目前只有两个维度,没有embedding维度,先扩展一个维度,方便后续的mask操作
		## [Batch_size, Seq_len, 1] 增广一维
		
		text_mask = text_mask.unsqueeze(-1)
		image_mask = image_mask.unsqueeze(-1)
		padding_mask = padding_mask.unsqueeze(-1)

		## expand函数会将指定的size为1的维度扩展到指定数量的维度,扩展方式是复制
		## 比如 [[1],[1],[1]]的shape为(3,1),现在用expand将其扩展为(-1,3)那么此时的输出是[[1,1,1],[1,1,1],[1,1,1]]
		## 此时需要把mask的embedding维度扩展到与input_embeds相同

		text_mask = text_mask.expand(-1,-1,embed_dim)
		image_mask = image_mask.expand(-1,-1,embed_dim)
		padding_mask = padding_mask.expand(-1,-1,embed_dim)

		## where函数根据条件进行替换,如果mask为True,那么替换为input_embeds,否则替换为final_embeddings
		final_embeddings = torch.where(text_mask,input_embeds,final_embeddings)

		## 这里也是依据image_mask选择是True的位置进行替换,之所以不能用where是因为image_feature的形状与mask的形状不匹配,image_mask函数只会关注那些mask为True的位置
		final_embeddings = final_embeddings.masked_scatter(image_mask,scaled_image_features)

		## padding的替换是把所有padding位置的嵌入向量变成0向量
		final_embeddings = torch.where(padding_mask,torch.zeros_like(final_embeddings),final_embeddings)


		dtype,device = dtype,device
		q = input_embeds.shape[1]
		min_dtype = torch.finfo(dtype).min

		if kv_cache is None or kv_cache.num_items() == 0: ##表明此时kv_cache是空的,说明是在预载入prompt阶段,此时根据论文可以不用屏蔽任何token,因为屏蔽了也没用,模型只会取最后一个token的embedding来预测下一个token,而最后一个embedding要求看到之前左右的token
			
			causal_mask = torch.full(
				(batch_size,q,q),
				fill_value= 0,
				dtype=dtype,
				device=device,
			)
		else: ##表明此时已经在推理过程中了,推理一次只会生成一个token,q == 1,而且推理过程中依据最新的这个token进行预测,所以也不需要屏蔽掉之前的token,需要屏蔽的置为负无穷,不需要屏蔽的置为0
			assert q == 1
			kv_len = kv_cache.num_items() + q
			causal_mask = torch.full(
				(batch_size,q,kv_len),
				fill_value= 0,
				dtype=dtype,
				device=device,
			)
		
		## causal_mask的形状为[Batch_size, Q_len, KV_len]
		## 需要增广一个多头注意力维度 [Batch_size, 1, Q_len, KV_len]
		causal_mask = causal_mask.unsqueeze(1)

		##attention_mask的形状为[Batch_size, Seq_len] 其中,1表示有效token,0表示padding token
		if kv_cache is not None and kv_cache.num_items() > 0:
			position_ids = attention_mask.cumsum(dim=-1) ##沿着最后一个维度计算sum 此时position_ids的形状为[Batch_size, Seq_len]
			##因为在推理过程中一次只会有一个q,所以,q的position_id是position_ids每一个batch中最后的那个sum的值
			position_ids = position_ids[:,-1]
			if position_ids.dim() == 1:
				position_ids = position_ids.unsqueeze(0) ##加一个batch维度
		else:
			##全量forward时,每一个token都要加上position_id
			position_ids = (attention_mask.cumsum(dim=-1).masked_fill((attention_mask == 0),1).to(device)) ##masked_fill将条件为真的位置的值替换为1,attention_mask: [Batch_size, Seq_len]

		return final_embeddings,causal_mask,position_ids

这里首先来解读一下函数的输入:

  • input_embeds: 代表每个token的嵌入向量,形状为[B,Seq,Embedding_size]
  • input_ids: 代表每个token的id,形状为[B,Seq]
  • attention_mask: 标识每个token是有效的还是填充的,形状为[B,Seq]
  • kv_cache: 用来指示当前推理过程中已经有了几个token了

在函数的初期,我们首先将图像特征统一除以了sqrt(embedding size)

scaled_image_features = image_features * (embed_dim ** -0.5)

这是因为后面会需要将所有的特征(图像+文本)乘以sqrt(embedding size)来让特征的方差趋近于1,但是图像特征的方差一般已经在1周围了,如果再去做乘积,就会使得图像的方差变的很大,两种模态的特征方差就不平衡了,所以这里提前做一些补偿。

随后我们准备了一个空的输入特征,来喂给Gemma 模型,此时全部用0填充。

final_embeddings = torch.zeros(batch_size,seq_len,embed_dim,dtype=dtype,device=device)

然后我们根据输入的token id,构建了用于区分image token,text token和padding token的三张掩码。并且我们让掩码的大小与输入特征的大小一致,这样可以方便final embedding的替换。

随后我们利用pytorch的scatter_mask函数和where函数来填充对应的特征

## where函数根据条件进行替换,如果mask为True,那么替换为input_embeds,否则替换为final_embeddings
		final_embeddings = torch.where(text_mask,input_embeds,final_embeddings)

		## 这里也是依据image_mask选择是True的位置进行替换,之所以不能用where是因为image_feature的形状与mask的形状不匹配,image_mask函数只会关注那些mask为True的位置
		final_embeddings = final_embeddings.masked_scatter(image_mask,scaled_image_features)

		## padding的替换是把所有padding位置的嵌入向量变成0向量
		final_embeddings = torch.where(padding_mask,torch.zeros_like(final_embeddings),final_embeddings)

至此,输入准备好了,但我们还需要准备模型后面会要用的注意力掩码和位置id。

注意力掩码可以根据kv cache中的信息来知道当前的token数量,从而构建注意力掩码矩阵,需要屏蔽的位置设置为负无穷,不需要屏蔽的位置为0。但此时,由于依次只有一个token推理,所以推理过程中注意力掩码无需屏蔽任何token,而在预加载用户的prompt时,论文中也没有屏蔽任何token,这是论文的做法,不同的人有不同的实现而已。

如图,这里在[sep]之前都没有屏蔽任何已有的token,sep之后由于单token推理,所以也无需屏蔽token。于是我们构建了全0的注意力掩码矩阵。

		if kv_cache is None or kv_cache.num_items() == 0: ##表明此时kv_cache是空的,说明是在预载入prompt阶段,此时根据论文可以不用屏蔽任何token,因为屏蔽了也没用,模型只会取最后一个token的embedding来预测下一个token,而最后一个embedding要求看到之前左右的token
			
			causal_mask = torch.full(
				(batch_size,q,q),
				fill_value= 0,
				dtype=dtype,
				device=device,
			)
		else: ##表明此时已经在推理过程中了,推理一次只会生成一个token,q == 1,而且推理过程中依据最新的这个token进行预测,所以也不需要屏蔽掉之前的token,需要屏蔽的置为负无穷,不需要屏蔽的置为0
			assert q == 1
			kv_len = kv_cache.num_items() + q
			causal_mask = torch.full(
				(batch_size,q,kv_len),
				fill_value= 0,
				dtype=dtype,
				device=device,
			)

最后我们根据attention mask来计算position id,因为这里我们没有做padding操作,因为我们一开始就限制了一次推理只有一个图像文本对,所以attention mask是全1的,所以可以通过对用前缀和的方式计算attention mask,得到每一个token的位置id。

由于推理阶段只有一个token会被传递后续进行attention计算 (这是因为kv cache已经记录了之前所有token的k向量和v向量,并且之前的token之间的qk计算是没有意义的,因为attention mask的存在,之前的一些token间的qk早在生成当前token之前已经做过了,所以无需将之前的token传递到后续阶段来计算attention)所以我们postion_ids最后一个作为当前tokenpostion_id传入进行旋转位置编码

每次只有产生那个tokenkv计算有效计算其余tokenkv计算重复计算

position_ids = position_ids[:,-1]

全量forward(即读取用户输入的prompt并计算kv cache)的过程由于此时kv cache所以用户prompt所有token一次性输入模型的,用来填充kv cache,所以所有postion_ids传入并且确保position_ids一个batch维度

注意全量forward过程我们没有形成类似于上三角矩阵这样注意力掩码因为这是论文做法仅此而已

随后我们返回 模型输入嵌入注意力掩码矩阵位置id位置id用于计算旋转位置编码

return final_embeddings,causal_mask,position_ids

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值