多模态学习笔记 - 语言模型篇(3)
参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE
吐槽
今天接着昨天的源码继续看,黑神话:悟空正好今天发售,希望广大coder能玩的开心~
学习心得
前情提要
详情请看多模态源码阅读 - 2
上次我们讲到利用view()函数对token_type_ids、position_ids进行重新塑形,确保这些张量的最后一个维度和input_shape(输入序列数据)的最后一个维度相等。重构的代码中默认启用缓存键值对(显然use_cache的bool值有点可有可无了QAQ),如果past_key_values的值为空,代表处于推理或者训练的第一步,此时我们初始化past_length为0,初始化past_key_values为长度为Qwen模型层数量的元组,self.h是Qwen模型的成员变量,我们无需太过关心(因为我们只是继承Qwen模型的成员变量,并重构了forward方法)。
如果我们当前不处于训练或推理的第一步,past_key_values显然就不为空(因为我们默认启用缓存键值对,ps:科研级代码是这样的),不管缓存量化(use_cache_quantization)启用与否,我们将past_length更新为第一个注意力头键张量的第二个或倒数第二个维度。这里唯一的区别只是元组的维数和维度不太一样。
如果position_ids为None,我们需要初始化一个position_ids,起始位置为past_length,终止位置为psst_lenght + input_shape[=1],确保我们的position_ids长度与input_shape的最后一个维度相等,随后重新塑形,同样是为了确保position_ids为二维张量,且最后一个维度与input_shape对齐,代码如下:
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
if self.use_cache_quantization:
past_length = past_key_values[0][0][0].size(2)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
新的记忆
代码块1
接着上面的代码,继续看MQwen.py中MQwenModel中重构的forward方法,代码如下:
if attention_mask is not None:
# image_feaute_length = self.otherConfig["image_context_length"]*self.otherConfig["image_feature_hidden_size"]
# attention_mask_length = attention_mask.shape[-1] - image_feaute_length + self.otherConfig["image_context_length"]
# attention_mask = torch.ones((batch_size, attention_mask_length), dtype=torch.long, device=device)
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype)
attention_mask