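The previous article introduced the main components of the pipeline. The last two steps are a little harder to understand. Their job is to format the data into the form that training expects, and both are defined in '/mmdet/datasets/pipelines/formating.py'.
1. DefaultFormatBundle
This step converts img, bboxes, labels and similar fields to tensors, and then wraps them in DataContainer objects. Packing the data into DataContainer objects makes it easier for the program to read during the later training process.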
# Excerpt from formating.py. There, np is numpy and DC is
# mmcv.parallel.DataContainer; to_tensor is a helper defined in the same
# file, and _add_default_meta_keys is another method of this class (not shown).
@PIPELINES.register_module()
class DefaultFormatBundle:
    """Default formatting bundle.

    It simplifies the formatting of common fields, including
    "img", "proposals", "gt_bboxes", "gt_labels", "gt_masks" and
    "gt_semantic_seg".
    The fields below in results are packed in the following order:

    - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
    - proposals: (1)to tensor, (2)to DataContainer
    - gt_bboxes: (1)to tensor, (2)to DataContainer
    - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
    - gt_labels: (1)to tensor, (2)to DataContainer
    - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, (3)to DataContainer (stack=True)
    """

    def __init__(self,
                 img_to_float=True,
                 pad_val=dict(img=0, masks=0, seg=255)):
        self.img_to_float = img_to_float
        self.pad_val = pad_val

    def __call__(self, results):
        if 'img' in results:
            img = results['img']
            if self.img_to_float is True and img.dtype == np.uint8:
                # Normally, image is of uint8 type without normalization.
                # At this time, it needs to be forced to be converted to
                # float32, otherwise the model training and inference
                # will be wrong. Only used for YOLOX currently.
                img = img.astype(np.float32)
            # add default meta keys
            results = self._add_default_meta_keys(results)
            if len(img.shape) < 3:
                # grayscale image: (H, W) -> (H, W, 1)
                img = np.expand_dims(img, -1)
            # (H, W, C) -> (C, H, W)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            results['img'] = DC(
                to_tensor(img), padding_value=self.pad_val['img'], stack=True)
        for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            results['gt_masks'] = DC(
                results['gt_masks'],
                padding_value=self.pad_val['masks'],
                cpu_only=True)
        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]),
                padding_value=self.pad_val['seg'],
                stack=True)
        return results
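The image branch of `__call__` can be followed on a toy array. The sketch below is a minimal NumPy-only illustration: `format_image` is a hypothetical helper, the array shapes are made up, and the final `to_tensor` and `DC` wrapping steps are left out, so it shows only the dtype and layout changes.

```python
import numpy as np


def format_image(img, img_to_float=True):
    """Mimic the array handling of the 'img' branch (no tensor/DC wrapping)."""
    if img_to_float and img.dtype == np.uint8:
        img = img.astype(np.float32)
    if img.ndim < 3:
        # Grayscale (H, W) -> (H, W, 1)
        img = np.expand_dims(img, -1)
    # (H, W, C) -> (C, H, W), copied into contiguous memory
    return np.ascontiguousarray(img.transpose(2, 0, 1))


color = np.zeros((4, 6, 3), dtype=np.uint8)  # hypothetical H=4, W=6, C=3
gray = np.zeros((4, 6), dtype=np.uint8)      # hypothetical grayscale image

print(format_image(color).shape, format_image(color).dtype)  # (3, 4, 6) float32
print(format_image(gray).shape)                               # (1, 4, 6)
```

The channel-first layout is what PyTorch convolution layers expect, which is why the transpose happens before the array becomes a tensor.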
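The collate function in PyTorch packs a batch by stacking its samples along some dimension. This approach has the following drawbacks:
1. All tensors must have the same size.
2. The supported types are rigid: only tensors or numpy arrays.
DataContainer was designed to pack tensors more flexibly, but PyTorch itself does not know how to pack or operate on DataContainer objects.
MMDataParallel was therefore designed as well, to handle the contents of a DC.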
class DataContainer:

    def __init__(self,
                 data,
                 stack=False,
                 padding_value=0,
                 cpu_only=False,
                 pad_dims=2):
        self._data = data
        # cpu_only: keep the data on the CPU (e.g. meta data)
        self._cpu_only = cpu_only
        # stack: whether samples are stacked into one batch tensor
        self._stack = stack
        # padding_value: fill value used when samples are padded
        self._padding_value = padding_value
        # pad_dims: number of trailing dimensions to pad
        assert pad_dims in [None, 1, 2, 3]
        self._pad_dims = pad_dims

    def __repr__(self):
        return f'{self.__class__.__name__}({repr(self.data)})'

    def __len__(self):
        return len(self._data)

    @property
    def data(self):
        return self._data

    @property
    def datatype(self):
        if isinstance(self.data, torch.Tensor):
            return self.data.type()
        else:
            return type(self.data)

    @property
    def cpu_only(self):
        return self._cpu_only

    @property
    def stack(self):
        return self._stack

    @property
    def padding_value(self):
        return self._padding_value

    @property
    def pad_dims(self):
        return self._pad_dims

    @assert_tensor_type
    def size(self, *args, **kwargs):
        return self.data.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self):
        return self.data.dim()
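The `stack`, `padding_value` and `pad_dims` attributes hint at how a batch of differently sized samples can still be combined: pad every sample to a common shape, then stack. The following is a rough NumPy sketch of that idea, not mmcv's actual collate code. `pad_and_stack` is a hypothetical helper, and the image shapes are invented.

```python
import numpy as np


def pad_and_stack(arrays, padding_value=0, pad_dims=2):
    """Pad the trailing `pad_dims` dims to the batch maximum, then stack."""
    ndim = arrays[0].ndim
    # Largest size along each dimension across the batch
    max_shape = [max(a.shape[d] for a in arrays) for d in range(ndim)]
    padded = []
    for a in arrays:
        # Leading dims must already match; only trailing dims are padded
        assert a.shape[:ndim - pad_dims] == arrays[0].shape[:ndim - pad_dims]
        pad_width = [(0, 0)] * (ndim - pad_dims) + [
            (0, max_shape[d] - a.shape[d]) for d in range(ndim - pad_dims, ndim)
        ]
        padded.append(np.pad(a, pad_width, constant_values=padding_value))
    return np.stack(padded)


imgs = [np.ones((3, 4, 6)), np.ones((3, 5, 5))]  # two CHW images, sizes differ

try:
    np.stack(imgs)  # plain stacking needs identical shapes
except ValueError as err:
    print('plain stack failed:', type(err).__name__)

batch = pad_and_stack(imgs, padding_value=0, pad_dims=2)
print(batch.shape)  # (2, 3, 5, 6)
```

Padding only at the end of each dimension keeps the original pixels at the same coordinates, so bounding boxes stay valid without any shift.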
2. Collect
Collect, as the name suggests, gathers data. It adds a new key, 'img_metas':
the keys 'filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg' and so on are taken from results, packed into one dict, and stored under 'img_metas' in the returned dict.
@PIPELINES.register_module()
class Collect:

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        data = {}
        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data['img_metas'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'
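With the configuration below, it returns a dict with four keys: {'img_metas': …, 'img': …, 'gt_bboxes': …, 'gt_labels': …}. The values listed in meta_keys end up inside 'img_metas'.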
Collect(keys=['img', 'gt_bboxes', 'gt_labels'], meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape', 'pad_shape', 'scale_factor', 'flip', 'flip_direction', 'img_norm_cfg'))
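To make the returned structure concrete, the sketch below re-implements the key selection with plain dicts. It is a simplified stand-in, not the library code: `collect` is a hypothetical function, the `DC` wrapper around `img_metas` is omitted, the sample values are invented, and only three meta keys are used to keep it short.

```python
def collect(results, keys, meta_keys):
    """Plain-dict version of Collect.__call__ (no DataContainer wrapping)."""
    data = {}
    # Gather the meta fields into one nested dict
    data['img_metas'] = {key: results[key] for key in meta_keys}
    # Copy over only the requested data fields
    for key in keys:
        data[key] = results[key]
    return data


# Invented sample; real pipelines carry arrays/tensors here
results = {
    'img': 'image-placeholder',
    'gt_bboxes': [[10, 20, 50, 80]],
    'gt_labels': [3],
    'filename': 'demo.jpg',
    'ori_shape': (480, 640, 3),
    'flip': False,
    'img_info': {'id': 1},  # not requested, so it is dropped
}

out = collect(results,
              keys=['img', 'gt_bboxes', 'gt_labels'],
              meta_keys=('filename', 'ori_shape', 'flip'))
print(list(out))  # ['img_metas', 'img', 'gt_bboxes', 'gt_labels']
```

Any field of results that appears in neither keys nor meta_keys is discarded, and a meta key missing from results raises KeyError in this sketch, because it indexes `results[key]` directly just as the original loop does.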