发现问题
问题简述
最近在做一个项目。项目里,每张待提特征的图都有自己的坐标。想把坐标和图一起通过torch的DataLoader
类打包成一个个batch,却发现输出的每个batch中,坐标和输入不一样。
具体情况
每张图的坐标由一个列表表示:[左上x,左上y,宽度,高度]
。先将其处理成中心坐标,即[左上x+1/2宽度,左上y+1/2高度]
。然后跟随图片打包成batch。
举个例子
如果输入为[(0,0,100,100), (100,100,100,100), (200,200,100,100)]
,batch_size>=3
那么第一个batch的坐标输出应该是[(50.0, 50.0),(150.0, 150.0)]
但是实际输出为:(50.0, 150.0),(50.0, 150.0)
分析问题
data_loader是打包成batch的数据
data_loader = DataLoader(dataset, batch_size = 32, num_workers = 0)
遍历data_loader的每一个batch,即
for batch_idx, c_points in enumerate(data_loader):
这里的c_point
就是被打包后的数据,和原数据不一样,因此接下来需要研究这个for循环。它调用了DataLoader类中的__iter__()
方法
def __iter__(self):
return _DataLoaderIter(self)
因此,for循环的每次迭代,都会调用_DataLoaderIter类的__next__()
方法。为了将重点放在本次遇到的问题上,设num_workers=0,pin_memory=False。因此在__next__()
方法中,实际执行的代码为:
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
return batch
下面分析这两句话
第一句
首先要知道self.sample_iter
是个什么东西。在当前类的初始化部分有:
self.sample_iter = iter(self.batch_sampler)
而self.batch_sampler
在DataLoader类中可以找到实现:
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
这里用到的是SequentialSampler类。另一个是训练用,只是加入了随机特性而已。因此实际执行的只有两行代码
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
对于第一行,在SequentialSampler类中,可以看到
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
它返回的是一个有关输入数据索引的迭代器,因此sampler
是一个关于输入数据索引的迭代器。
对于第二行,找到BatchSampler类的__iter__()方法
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
在这里就不得不提一下yield这个东西。它和return很像,可以返回值,但是程序并不会跳出,而是继续进行,因此它可以实现迭代器的功能,每次返回一个batch。在这个batch中,存的是一组索引。比如有100个数据,batch_size=3,那么在第二次迭代时,batch=[3,4,5]。
因此,针对第一句话可以解释如下:
indices = next(self.sample_iter) # may raise StopIteration
self.sample_iter = iter(self.batch_sampler)
是一个迭代器,它遍历每个batch对应的索引。indices = next(self.sample_iter)
就是在执行迭代步骤,indices就是当前batch对应的索引。
第二句
明白第一句的作用后,再看第二句。先将其列如下
batch = self.collate_fn([self.dataset[i] for i in indices])
其中[self.dataset[i] for i in indices]
很好理解,就是输入数据的某一个batch。但是这个self.collate_fn是什么鬼?马上定位到其实现部分
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
elem_type = type(batch[0])
if isinstance(batch[0], torch.Tensor):
out = None
if _use_shared_memory:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = batch[0].storage()._new_shared(numel)
out = batch[0].new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if re.search('[SaUO]', elem.dtype.str) is not None:
raise TypeError(error_msg.format(elem.dtype))
return torch.stack([torch.from_numpy(b) for b in batch], 0)
if elem.shape == (): # scalars
py_type = float if elem.dtype.name.startswith('float') else int
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
elif isinstance(batch[0], int_classes):
return torch.LongTensor(batch)
elif isinstance(batch[0], float):
return torch.DoubleTensor(batch)
elif isinstance(batch[0], string_classes):
return batch
elif isinstance(batch[0], container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], container_abcs.Sequence):
print('container_abcs.Sequence')
transposed = zip(*batch)
for samples in transposed:
print(samples)
return [default_collate(samples) for samples in transposed]
raise TypeError((error_msg.format(type(batch[0]))))
可以发现,入参名为batch,它先判断batch[0]的数据类型。对于本文的问题,它是一个列表,属于container_abcs.Sequence,即执行以下代码
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
注意到这里有一个zip方法,而且在入参前面还加了一个*
。对于*
的作用,网上解释如下:在函数调用多个参数时,在列表、元组、集合、字典及其他可迭代对象前加*
,可以使其自动解包并传递给相应单个变量,其中参数个数个元素个数要相等。
举个例子
def foo(a,b,c):
print(a,b,c)
foo(*[1,2,3])
结果
1 2 3
在zip(*batch)
中,因为有*
,所以batch中的每一个元素作为zip()
的第1、2、3…个输入参数;如果不加*
,那么整个batch
作为zip()
的唯一输入参数。区别如下
batch = [[1,2], [3,4]]
transposed = zip(*batch)
for samples in transposed:
print(samples)
transposed = zip(batch)
for samples in transposed:
print(samples)
结果
(1, 3)
(2, 4)
([1, 2],)
([3, 4],)
这便是问题的根源所在。
这里插一脚,zip()
返回的东西可以看成迭代器,迭代完了就完了,如下
batch = [[1,2], [3,4]]
transposed = zip(*batch)
for samples in transposed:
print(samples)
for samples in transposed:
print(samples)
结果
(1, 3)
(2, 4)
下面继续分析问题,注意到在zip()
函数后还有一句
return [default_collate(samples) for samples in transposed]
即把zip()
后的结果再次逐元素地送回原来的default_collate()函数打包。
假设batch=[[1,1],[2,2],[3,3]],transposed = zip(*batch)
的遍历结果为(1,2,3),(1,2,3),那么其中每个元素samples=(1,2,3)将再次被打包。由于其第0位的数据类型为int,所以进入
elif isinstance(batch[0], int_classes):
return torch.LongTensor(batch)
即将其转为tensor类型。这些被转为tensor类型的元素被放入一个列表中返回,如果沿用前面的例子,那么返回结果为[tensor([ 1., 2., 3.], dtype=torch.float64), tensor([ 1., 2.,3.], dtype=torch.float64)]
。
也就是第二句中batch
的值。
batch = self.collate_fn([self.dataset[i] for i in indices])
同样也是最初for循环中c_point
的值。
for batch_idx, c_points in enumerate(data_loader):