Using a torch.utils.data Dataset with the Trainer

Dataset adaptation

In "Transformers in Practice — Training and Evaluating Your Own Data and Models with the Trainer Class", the dataset class used was from datasets import Dataset. In common practice, however, PyTorch projects use the dataset classes from torch.utils.data.

To make a PyTorch Dataset class compatible with the transformers Trainer, the class's def __getitem__(self, idx) method needs to be modified as follows:

    def __getitem__(self, idx):
        # The concrete implementation here is up to you
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # The return value below must be changed into a dictionary
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In short, change the return value into a dictionary whose keys match the parameter names of the model's forward method.
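Putting this together, a minimal complete Dataset might look like the sketch below. The class name, field names, and the assumption that the tokenizer exposes encode_plus (as Hugging Face tokenizers do) are illustrative, not from the original article:

```python
import torch
from torch.utils.data import Dataset


class TextClassificationDataset(Dataset):
    """Minimal sketch of a torch Dataset returning Trainer-compatible dicts."""

    def __init__(self, texts, labels, tokenizer):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer.encode_plus(
            self.texts[idx],
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # Keys must match the model's forward() parameter names
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long),
        }
```

The flatten() calls matter: with return_tensors='pt', encode_plus returns tensors of shape (1, seq_len), but each sample should be 1-D so that batching adds the batch dimension.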

DataCollator

The change above controls what each individual sample returns, but during training the model consumes samples in batches. If every sample's input shape is fixed, the previous step is enough for the model to run normally.

Sometimes, however, input shapes differ between samples, as in text classification (every article has a different length), and training will fail.

There are generally two ways to handle this (using text classification as the example):

  1. Modify __getitem__: add a maximum length max_length and pad (with zeros) or truncate every text to a fixed length (e.g. 512).
  2. Use a DataCollator class from transformers, such as DataCollatorWithPadding, which is commonly used for text padding. This approach requires no changes to the Dataset class.
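For method 1, a Hugging Face tokenizer can do this directly by adding padding='max_length', max_length=512, truncation=True to the encode_plus call. The underlying operation is simple; a pure-Python sketch (the function name and pad id are illustrative):

```python
def pad_or_truncate(token_ids, max_length=512, pad_id=0):
    """Pad a list of token ids with pad_id, or truncate it, to exactly max_length."""
    if len(token_ids) >= max_length:
        return token_ids[:max_length]
    return token_ids + [pad_id] * (max_length - len(token_ids))
```

The drawback is that every sample is padded to the global maximum, even in batches where all texts are short.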

An example of using a DataCollator:

    from transformers import DataCollatorWithPadding

    # The tokenizer variable here can be a BertTokenizer instance
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt')

    trainer = Trainer(
        model=model,                      # the instantiated 🤗 Transformers model to be trained
        args=training_args,               # training arguments, defined above
        train_dataset=train_dataset,      # training dataset
        eval_dataset=dev_dataset,         # evaluation dataset
        compute_metrics=compute_metrics,  # metric computation function
        data_collator=data_collator,      # pass in the DataCollator
    )

This approach pads dynamically to the maximum text length within each batch, which saves memory to some extent.
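The dynamic behavior can be illustrated with a hand-written collate function. This is a sketch of the idea behind DataCollatorWithPadding, not its actual implementation; the function name and pad id are made up:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def dynamic_pad_collate(batch, pad_id=0):
    """Pad each field only to the longest sequence within this batch."""
    input_ids = pad_sequence(
        [item['input_ids'] for item in batch],
        batch_first=True, padding_value=pad_id,
    )
    attention_mask = pad_sequence(
        [item['attention_mask'] for item in batch],
        batch_first=True, padding_value=0,
    )
    labels = torch.stack([item['labels'] for item in batch])
    return {'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels}
```

A batch whose longest text has 7 tokens is padded to length 7, not to a global maximum such as 512, which is where the memory saving comes from.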
