批量数据的数据整理器的编写(重写)

data_collator库中定义了一个DynamicDataCollatorWithPadding类,这是一个专门用于处理批量数据的数据整理器(collator),它继承自transformers库中的DataCollatorWithPadding类。下面是对这个类的编写思路、编写目的和作用的详细分析:

编写思路

  1. 继承现有功能:通过继承DataCollatorWithPadding,利用其已有的填充功能。

    class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
        r"""
        Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data.
        """
        def __init__(
                self,
                tokenizer: PreTrainedTokenizer,
                ignore_pad_token_for_loss: Optional[bool] = False
        ):
            super().__init__(tokenizer, padding=True)
            self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

  2. 定制化修改:在初始化函数中,根据是否忽略填充标记(pad token)对损失计算的影响来设置label_pad_token_id

  3. 批量处理优化:定义get_attention_masks函数来动态生成注意力掩码(attention masks),处理左填充(left-padding)的序列。

    def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length), device=device)
    ​
        for i, seq in enumerate(input_ids):
            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
    ​
        attention_mask = attention_mask.bool()
        return attention_mask

  4. 调用函数定制:重写__call__方法以支持批量数据的左填充,这在某些特定模型训练中是必需的。

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
        r"""
        Pads batched data to the longest sequence in the batch.
    ​
        We adopt left-padding in both training and evaluation.
        """
        if isinstance(features[0]["input_ids"], torch.Tensor):
            input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
        else:
            input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
    ​
        if "labels" in features[0]:
            if isinstance(features[0]["labels"], torch.Tensor):
                labels = [feature["labels"].clone().detach().flip(0) for feature in features]
            else:
                labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
            input_ids = input_ids + labels # pad them to the same length
    ​
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id
        ).flip(-1)
    ​
        batch = {}
    ​
        if "labels" in features[0]:
            input_ids, labels = input_ids.split(len(features), dim=0)
            labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
            batch["labels"] = labels
    ​
        batch["input_ids"] = input_ids
        batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
    ​
        return BatchEncoding(batch)

编写目的

  • 支持动态填充:适应不同长度的输入序列,动态地为较短的序列添加填充,确保所有序列长度一致,以便进行批量处理。

  • 优化模型输入:通过生成和管理有效的注意力掩码,确保模型在处理填充序列时能够正确地忽略填充部分。

  • 提升训练效率:通过高效地处理批量数据输入,提升模型训练的效率和数据处理的效率。

作用

  • 数据整理DynamicDataCollatorWithPadding为批量数据提供了一个高效的整理机制,确保所有数据批次在送入模型之前具有统一的形状和结构。

  • 注意力掩码生成:自动为每个批次生成正确的注意力掩码,这对于基于注意力机制的模型来说是必须的,以确保模型正确理解和处理输入数据中的实际内容与填充部分。

  • 支持自定义填充:后续编写可以根据是否需要在损失计算中考虑填充标记来定制数据整理器的行为。

data_collator库中的DynamicDataCollatorWithPadding类为基于transformers的NLP模型训练提供了重要的数据预处理支持,特别是在处理不等长序列数据时提供了必要的功能,确保数据输入的一致性和高效性。

  • 18
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值