bloom模型类似于gpt,适用于语言建模任务,可以用于生成文本、预测下一个词等任务。
bloom主体注意就是decoder加上了alibi位置编码,我们主要关注一下head的代码。head主要并供训练损失或logits(预测结果),是非常重要的一个部分。
class BloomHead(nn.Cell):
"""Head for Bloom to get the logits of each token in the vocab."""
def __init__(self,
hidden_size,
vocab_size,
compute_type="float16",
parallel_config=None):
super(BloomHead, self).__init__()
copied_parallel_config = copy.deepcopy(parallel_config)
mp = copied_parallel_config.model_parallel
if vocab_size % mp != 0:
logger.warning("The vocab size of BloomHead MatMul is: %s, it is not divide by model_parallel: %s",
vocab_size, mp)
logger.warning("Now, the model_parallel num of BloomHead MatMul will be changed: mp = 1")
copied_parallel_config.model_parallel = 1
if copied_parallel_config.pipeline_stage > 1:
copied_parallel_config.vocab_emb_dp = False
if copied_parallel_config.vocab_emb_dp:
self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (1, 1)))
else:
self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (
copied_parallel_config.model_parallel, 1)))
self.hidden_size = hidden_size
self.dtype = convert_mstype(compute_type)
def construct(self, state, embedding_table):
ori_dtype = state.dtype
state = state.reshape((-1, self.hidden_size))
logits = self.matmul(state.astype(self.dtype), embedding_table.astype(self.dtype))
return logits.astype(ori_dtype)
BloomHead
的类,继承自 nn.Cell
。它是用于 Bloom 模型的头部,负责将模型的隐藏状态与嵌入表进行矩阵乘法运算得到logits。
构造函数 __init__
接受一些参数,包括隐藏层大小 hidden_size
、词汇表大小 vocab_size
、计算类型 compute_type
(默认为 "float16")和并行配置 parallel_config
(默认为 None
)。
在构造函数中,它根据并行配置来初始化并行模式,并处理模型并行不可整除词汇表规模的情况,将相应的警告信息记录下来。
然后,根据并行配置的设置,选择合适的操作符 MatMul
,并对其进行划分和分片。
属性 hidden_size
记录隐藏层的大小,属性 dtype
记录计算类型。
construct
方法是模型的前向传播函数,接受两个张量作为输入:state
和 embedding_table
。首先,将 state
重塑为2维张量,形状为 (-1, hidden_size)。然后,将 state
转换为指定的计算类型,并通过矩阵乘法操作 matmul
将其与 embedding_table
相乘得到logits。最后,将logits转换回原始的数据类型 ori_dtype
并返回。
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class BloomLMHeadModel(BaseModel):
"""
Provide bloom training loss or logits through network.
Args:
config (BloomConfig): The config of BloomModel.
Returns:
Tensor, the loss or logits of the network.
"""
_support_list = MindFormerBook.get_model_support_list()['bloom']
def __init__(self, config=None):
config = config if config is not None else BloomConfig()
super(BloomLMHeadModel, self).__init__(config, auto_prefix=False)
self.eos_token_id = self.config.eos_token_id
parallel_config = self.config.parallel_config
self.stridedslice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))
self.transformer = BloomModel(self.config)
self.head = BloomHead(hidden_size=config.hidden_size,
vocab_size=config.vocab_size,
parallel_config=self.config.parallel_config)
if parallel_config.pipeline_stage > 1:
self.head.pipeline_stage = parallel_config.pipeline_stage - 1
self.transformer.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
mp = config.parallel_config.model_parallel
vocab_size = config.vocab_size
loss_parallel_config = copy.deepcopy(parallel_config)
if vocab_size % mp != 0:
logger.warning("The vocab size of Bloom Loss is: %s, it is not divide by model_parallel: %s",
vocab_size, mp)
logger.warning("Now, the model_parallel num of Bloom Loss will be changed: mp = 1")
loss_parallel_config.model_parallel = 1
self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)
self.load_checkpoint(config)
def construct(self, input_ids):
"""
construct function for Language Modeling
Args:
input_ids (Tensor): the indices of input sequence tokens in the vocabulary.
Returns:
logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
otherwise, return the computed loss.
"""
batch_size, seq_length = input_ids.shape
if self.phase == "train":
tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
else:
tokens = input_ids
input_mask = self.not_equal(tokens, self.eos_token_id).astype(mstype.float32)
# [batch_size, seq_length, vocab_size]
output_states, embedding_table = self.transformer(tokens, input_mask)
logits = self.head(output_states, embedding_table)
if self.phase != 'train':
return logits, tokens, input_mask
labels = self.stridedslice(input_ids, (0, 1), (batch_size, seq_length), (1, 1))
labels = labels.reshape((-1,))
input_mask = input_mask.reshape((-1,))
loss = self.loss(logits, labels, input_mask)
return loss
这段代码定义了一个名为 BloomLMHeadModel
的类,继承自 BaseModel
。
构造函数 __init__
接受一个配置对象 config
,如果未提供则使用默认配置 BloomConfig()
。在构造函数中,它初始化了一些子模块,包括 transformer
和 head
,以及一些计算所需的操作。
其中,construct
方法是模型的前向传播函数,接受一个名为 input_ids
的张量作为输入。在训练阶段,它会对输入序列进行切片,去除最后一个token,并生成相应的mask。
然后,将切片后的tokens输入到 transformer
模块中进行计算,得到输出状态和嵌入表。
接下来,调用 head
模块将输出状态和嵌入表映射为logits。
如果当前是训练阶段,还需要计算损失值。首先从 input_ids
中提取标签部分,并将其展平为一维张量。然后,计算损失值并返回。
如果当前不是训练阶段,直接返回logits。