需求描述
- 在 HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务 文章中的
3.Dataset
部分曾提到过:当使用 HuggingFace 提供的 Trainer 进行训练时,对于自定义数据集,一定要满足:
__getitem__
方法的返回值形式一定要是{"labels": xxx, "pixel_values": xxx}
。 - 但如果我希望输入不仅仅是图像的 tensor 表示,还希望增加另一个维度的输入,比如 camera_position 的 tensor 表示。也就是我希望自定义数据集的
__getitem__
方法的返回值如下所示:
{
// 将 pixel_values 和 camera_position 作为多维输入
'pixel_values': 图像转为 tensor 变量,
'camera_positon': np.array([0.0, 1.0, 0.0]),
'labels': 标签
}
- 但是我发现,在训练过程中会报错:找不到 camera_position,这是因为HuggingFace 会将数据集进行处理,转化为只有 ‘pixel_values’ 和 ‘labels’ 两个属性。
需求实现
如果想保留 ‘camera_position’,则需要修改 transformers 的源码:
- 将
transformers.trainer_utils.py
中的代码features = [self._remove_columns(feature) for feature in features]
删除。
transformers.trainer_utils.py 的部分代码如下:
在我的计算机中,transformers.trainer_utils.py 文件的全路径为
D:\Anaconda3\envs\transformers\Lib\site-packages\transformers
其中 transformers 为使用 conda create -n transfoemers 创建的环境名称。
class RemoveColumnsCollator:
"""Wrap the data collator to remove unused columns before they are passed to the collator."""
def __init__(
self,
data_collator,
signature_columns,
logger=None,
model_name: Optional[str] = None,
description: Optional[str] = None,
):
self.data_collator = data_collator
self.signature_columns = signature_columns
self.logger = logger
self.description = description
self.model_name = model_name
self.message_logged = False
def _remove_columns(self, feature: dict) -> dict:
if not isinstance(feature, dict):
return feature
if not self.message_logged and self.logger and self.model_name:
ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
if len(ignored_columns) > 0:
dset_description = "" if self.description is None else f"in the {self.description} set"
self.logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
" you can safely ignore this message."
)
self.message_logged = True
return {k: v for k, v in feature.items() if k in self.signature_columns}
def __call__(self, features: List[dict]):
# 修改如下:将这一行注释掉,不要删除多余的属性
# features = [self._remove_columns(feature) for feature in features]
return self.data_collator(features)