HuggingFace 自定义数据集,使用 Trainer 训练,输入增加多个维度信息

需求描述

  1. HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务 文章中的 3.Dataset 部分曾提到过:当使用 HuggingFace 提供的 Trainer 进行训练时,对于自定义数据集,一定要满足:
    __getitem__ 方法的返回值形式一定要是 {"labels": xxx, "pixel_values": xxx}
  2. 但如果我希望输入不仅仅是图像的 tensor 表示,还希望增加另一个维度的输入,比如 camera_position 的 tensor 表示。也就是我希望自定义数据集的 __getitem__ 方法的返回值如下所示:
{
	// 将 pixel_values 和 camera_position 作为多维输入
	'pixel_values': 图像转为 tensor 变量,
	'camera_positon': np.array([0.0, 1.0, 0.0]), 
	'labels': 标签
}
  1. 但是我发现,在训练过程中会报错:找不到 camera_position,这是因为HuggingFace 会将数据集进行处理,转化为只有 ‘pixel_values’ 和 ‘labels’ 两个属性。

需求实现

如果想保留 ‘camera_position’,则需要修改 transformers 的源码:

  • transformers.trainer_utils.py 中的代码 features = [self._remove_columns(feature) for feature in features] 删除。

transformers.trainer_utils.py 的部分代码如下:

在我的计算机中,transformers.trainer_utils.py 文件的全路径为 D:\Anaconda3\envs\transformers\Lib\site-packages\transformers
其中 transformers 为使用 conda create -n transfoemers 创建的环境名称。

class RemoveColumnsCollator:
    """Wrap the data collator to remove unused columns before they are passed to the collator."""

    def __init__(
        self,
        data_collator,
        signature_columns,
        logger=None,
        model_name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        self.data_collator = data_collator
        self.signature_columns = signature_columns
        self.logger = logger
        self.description = description
        self.model_name = model_name
        self.message_logged = False

    def _remove_columns(self, feature: dict) -> dict:
        if not isinstance(feature, dict):
            return feature
        if not self.message_logged and self.logger and self.model_name:
            ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
            if len(ignored_columns) > 0:
                dset_description = "" if self.description is None else f"in the {self.description} set"
                self.logger.info(
                    f"The following columns {dset_description} don't have a corresponding argument in "
                    f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
                f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
                " you can safely ignore this message."
                )
                self.message_logged = True
        return {k: v for k, v in feature.items() if k in self.signature_columns}

    def __call__(self, features: List[dict]):
        # 修改如下:将这一行注释掉,不要删除多余的属性
        # features = [self._remove_columns(feature) for feature in features]
        return self.data_collator(features)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

悄悄地努力

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值