一、问题:Too many open files error
问题讨论:Too many open files error · Issue #11201 · pytorch/pytorch (github.com)
解决方法:
二、MaskFormer
maskformer
官方代码依赖detectron2
库,我没有安装成功,直接把源码下载下来,然后将里面的detectron2
包拷贝的MaskFormer的项目目录下,可以正常使用。
2.1 数据加载流程
数据加载器的构建类:detectron2/data/build.py
文件下的build_detection_train_loader
。
数据加载使用到的类所在文件路径:
MapDataset
:detectron2/data/common.pyMaskFormerSemanticDatasetMapper
:mask_former/data/dataset_mappers/mask_former_semantic_dataset_mapper.pyTrainingSampler
: detectron2/data/samplers/distributed_sampler.py
流程:
-
在train_net.py中,找到Trainer的父类
DefaultTrainer
,在父类中,会有构建模型、优化器和数据加载器的代码:# Assume these objects must be constructed in this order. model = self.build_model(cfg) optimizer = self.build_optimizer(cfg, model) data_loader = self.build_train_loader(cfg)
-
实际调用的
Trainer
中的build_train_loader
方法,首先会创建一个mapper; -
然后,将mapper作为参数,调用
build_detection_train_loader
; -
在
build_detection_train_loader
中,将根据_train_loader_from_config
来构建没有传递的参数,比如dataset、mapper、sampler,如果我们有自己的参数,可以在调用build_detection_train_loader
是直接传递进去。# 传递了mapper build_detection_train_loader(cfg, mapper=mapper)
-
后面就是的调用建立真正的数据加载器。
2.2 Mapper
Mapper是由MaskFormerSemanticDatasetMapper
类实现的,具体作用:将图像和标签,转化为Detetron2要求的格式。
maskformer在进行训练时需要:
data={
'image':图像,shape=(3,256,256),
'sem_seg_gt':标签,shape=(256,256),
'classes':当前图像中的所有类别,是一个一维的torch数组,shape=(4),
'masks':每一个类别的区域蒙版,shape=(4,256,256),
}
classes=np.unique(sem_seg_gt)
masks = []
for class_id in classes:
masks.append(sem_seg_gt == class_id)
Mapper就是对图像和标签进行了处理得到训练模型需要的数据及格式。
注意:由于这时的数据格式已经不是普通的torch数组或者numpy数据,而是复杂的json,我们需要重新自定义一个collator函数,告诉程序,如何将一个batch的数据堆叠起来。
这里实现比较简单,直接什么都不处理,也就是不需要每一个项都堆叠,直接以列表形式组织在一起就可以了。
def trivial_batch_collator(batch): """ A batch collator that does nothing. """ return batch
2.3 Sampler
maskformer默认的数据采样器是:TrainerSample,这是一个无限流的采样器,如果你使用for进行迭代获取数据,将会永不终止,不断地随机获取数据索引。
实际使用时,我进行了修改,不想要无限流这个特点。
def __iter__(self):
start = self._rank
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self._size, generator=g).tolist()
else:
yield from torch.arange(self._size).tolist()
# 改之后
def __iter__(self):
return iter(torch.randperm(self._size).tolist())
2.4 数据集格式
在这里,我并没有使用官方的数据集格式,我的数据集:
----images
-----001.jpg
-----002.jpg
-----003.jpg
----labels
-----001.png
-----002.png
-----003.png
然后,在构建数据加载器时,自然也没有使用官方的数据集类,而是将自己写的类直接作为参数传递给构建函数,这样后面程序看到dataset不是None,也就不会根据配置文件创建一个官方dataset类。
build_detection_train_loader(cfg, mapper=mapper,dataset=dataset)