华为mindspore-Bloom模块代码分析1.head/BloomLMHeadModel

bloom模型类似于gpt,适用于语言建模任务,可以用于生成文本、预测下一个词等任务。

bloom主体注意就是decoder加上了alibi位置编码,我们主要关注一下head的代码。head主要并供训练损失或logits(预测结果),是非常重要的一个部分。

class BloomHead(nn.Cell):
    """Head for Bloom to get the logits of each token in the vocab."""
    def __init__(self,
                 hidden_size,
                 vocab_size,
                 compute_type="float16",
                 parallel_config=None):
        super(BloomHead, self).__init__()
        copied_parallel_config = copy.deepcopy(parallel_config)
        mp = copied_parallel_config.model_parallel
        if vocab_size % mp != 0:
            logger.warning("The vocab size of BloomHead MatMul is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of BloomHead MatMul will be changed: mp = 1")
            copied_parallel_config.model_parallel = 1

        if copied_parallel_config.pipeline_stage > 1:
            copied_parallel_config.vocab_emb_dp = False
        if copied_parallel_config.vocab_emb_dp:
            self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (1, 1)))
        else:
            self.matmul = P.MatMul(transpose_b=True).shard(((copied_parallel_config.data_parallel, 1), (
                copied_parallel_config.model_parallel, 1)))
        self.hidden_size = hidden_size
        self.dtype = convert_mstype(compute_type)

    def construct(self, state, embedding_table):
        ori_dtype = state.dtype
        state = state.reshape((-1, self.hidden_size))
        logits = self.matmul(state.astype(self.dtype), embedding_table.astype(self.dtype))
        return logits.astype(ori_dtype)

BloomHead 的类,继承自 nn.Cell。它是用于 Bloom 模型的头部,负责将模型的隐藏状态与嵌入表进行矩阵乘法运算得到logits。

构造函数 __init__ 接受一些参数,包括隐藏层大小 hidden_size、词汇表大小 vocab_size、计算类型 compute_type(默认为 "float16")和并行配置 parallel_config(默认为 None)。

在构造函数中,它根据并行配置来初始化并行模式,并处理模型并行不可整除词汇表规模的情况,将相应的警告信息记录下来。

然后,根据并行配置的设置,选择合适的操作符 MatMul,并对其进行划分和分片。

属性 hidden_size 记录隐藏层的大小,属性 dtype 记录计算类型。

construct 方法是模型的前向传播函数,接受两个张量作为输入:stateembedding_table。首先,将 state 重塑为2维张量,形状为 (-1, hidden_size)。然后,将 state 转换为指定的计算类型,并通过矩阵乘法操作 matmul 将其与 embedding_table 相乘得到logits。最后,将logits转换回原始的数据类型 ori_dtype 并返回。

@MindFormerRegister.register(MindFormerModuleType.MODELS)
class BloomLMHeadModel(BaseModel):
    """
    Provide bloom training loss or logits through network.

    Args:
        config (BloomConfig): The config of BloomModel.

    Returns:
        Tensor, the loss or logits of the network.
    """

    _support_list = MindFormerBook.get_model_support_list()['bloom']

    def __init__(self, config=None):
        config = config if config is not None else BloomConfig()
        super(BloomLMHeadModel, self).__init__(config, auto_prefix=False)

        self.eos_token_id = self.config.eos_token_id
        parallel_config = self.config.parallel_config
        self.stridedslice = P.StridedSlice().shard(((parallel_config.data_parallel, 1),))
        self.not_equal = P.NotEqual().shard(((parallel_config.data_parallel, 1), ()))

        self.transformer = BloomModel(self.config)
        self.head = BloomHead(hidden_size=config.hidden_size,
                              vocab_size=config.vocab_size,
                              parallel_config=self.config.parallel_config)
        if parallel_config.pipeline_stage > 1:
            self.head.pipeline_stage = parallel_config.pipeline_stage - 1
            self.transformer.embedding.word_embedding.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

        mp = config.parallel_config.model_parallel
        vocab_size = config.vocab_size
        loss_parallel_config = copy.deepcopy(parallel_config)
        if vocab_size % mp != 0:
            logger.warning("The vocab size of Bloom Loss is: %s, it is not divide by model_parallel: %s",
                           vocab_size, mp)
            logger.warning("Now, the model_parallel num of Bloom Loss will be changed: mp = 1")
            loss_parallel_config.model_parallel = 1

        self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config)
        self.load_checkpoint(config)

    def construct(self, input_ids):
        """
        construct function for Language Modeling

        Args:
            input_ids (Tensor): the indices of input sequence tokens in the vocabulary.

        Returns:
            logits (Tensor) or loss (mstype.float32): if is_training is False, directly return the logits,
                                                      otherwise, return the computed loss.
        """

        batch_size, seq_length = input_ids.shape

        if self.phase == "train":
            tokens = self.stridedslice(input_ids, (0, 0), (batch_size, seq_length - 1), (1, 1))
        else:
            tokens = input_ids

        input_mask = self.not_equal(tokens, self.eos_token_id).astype(mstype.float32)

        # [batch_size, seq_length, vocab_size]
        output_states, embedding_table = self.transformer(tokens, input_mask)
        logits = self.head(output_states, embedding_table)

        if self.phase != 'train':
            return logits, tokens, input_mask

        labels = self.stridedslice(input_ids, (0, 1), (batch_size, seq_length), (1, 1))
        labels = labels.reshape((-1,))
        input_mask = input_mask.reshape((-1,))
        loss = self.loss(logits, labels, input_mask)
        return loss

这段代码定义了一个名为 BloomLMHeadModel 的类,继承自 BaseModel

构造函数 __init__ 接受一个配置对象 config,如果未提供则使用默认配置 BloomConfig()。在构造函数中,它初始化了一些子模块,包括 transformerhead,以及一些计算所需的操作。

其中,construct 方法是模型的前向传播函数,接受一个名为 input_ids 的张量作为输入。在训练阶段,它会对输入序列进行切片,去除最后一个token,并生成相应的mask。

然后,将切片后的tokens输入到 transformer 模块中进行计算,得到输出状态和嵌入表。

接下来,调用 head 模块将输出状态和嵌入表映射为logits。

如果当前是训练阶段,还需要计算损失值。首先从 input_ids 中提取标签部分,并将其展平为一维张量。然后,计算损失值并返回。

如果当前不是训练阶段,直接返回logits。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值