关于torch加载list类型数据的思考

发现问题

问题简述

最近在做一个项目。项目里,每张待提特征的图都有自己的坐标。想把坐标和图一起通过torch的DataLoader类打包成一个个batch,却发现输出的每个batch中,坐标和输入不一样。

具体情况

每张图的坐标由一个列表表示:[左上x,左上y,宽度,高度]。先将其处理成中心坐标,即[左上x+1/2宽度,左上y+1/2高度]。然后跟随图片打包成batch。

举个例子

如果输入为[(0,0,100,100), (100,100,100,100), (200,200,100,100)]batch_size>=3
那么第一个batch的坐标输出应该是[(50.0, 50.0),(150.0, 150.0)]
但是实际输出为:(50.0, 150.0),(50.0, 150.0)

分析问题

data_loader是打包成batch的数据

data_loader = DataLoader(dataset, batch_size = 32, num_workers = 0)

遍历data_loader的每一个batch,即

for batch_idx, c_points in enumerate(data_loader):

这里的c_point就是被打包后的数据,和原数据不一样,因此接下来需要研究这个for循环。它调用了DataLoader类中的__iter__()方法

def __iter__(self):
    return _DataLoaderIter(self)

因此,for循环的每次迭代,都会调用_DataLoaderIter类的__next__()方法。为了将重点放在本次遇到的问题上,设num_workers=0,pin_memory=False。因此在__next__()方法中,实际执行的代码为:

indices = next(self.sample_iter)  # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
return batch

下面分析这两句话

第一句

首先要知道self.sample_iter是个什么东西。在当前类的初始化部分有:

self.sample_iter = iter(self.batch_sampler)

self.batch_sampler在DataLoader类中可以找到实现:

if batch_sampler is None:
    if sampler is None:
        if shuffle:
            sampler = RandomSampler(dataset)
        else:
            sampler = SequentialSampler(dataset)
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)

这里用到的是SequentialSampler类。另一个是训练用,只是加入了随机特性而已。因此实际执行的只有两行代码

sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)

对于第一行,在SequentialSampler类中,可以看到

def __init__(self, data_source):
    self.data_source = data_source

def __iter__(self):
    return iter(range(len(self.data_source)))

它返回的是一个有关输入数据索引的迭代器,因此sampler是一个关于输入数据索引的迭代器。
对于第二行,找到BatchSampler类的__iter__()方法

def __iter__(self):
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch

在这里就不得不提一下yield这个东西。它和return很像,可以返回值,但是程序并不会跳出,而是继续进行,因此它可以实现迭代器的功能,每次返回一个batch。在这个batch中,存的是一组索引。比如有100个数据,batch_size=3,那么在第二次迭代时,batch=[3,4,5]。
因此,针对第一句话可以解释如下:

indices = next(self.sample_iter)  # may raise StopIteration

self.sample_iter = iter(self.batch_sampler)是一个迭代器,它遍历每个batch对应的索引。indices = next(self.sample_iter)就是在执行迭代步骤,indices就是当前batch对应的索引。

第二句

明白第一句的作用后,再看第二句。先将其列如下

batch = self.collate_fn([self.dataset[i] for i in indices])

其中[self.dataset[i] for i in indices]很好理解,就是输入数据的某一个batch。但是这个self.collate_fn是什么鬼?马上定位到其实现部分

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))

            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int_classes):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        print('container_abcs.Sequence')
        transposed = zip(*batch)
        for samples in transposed:
            print(samples)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg.format(type(batch[0]))))

可以发现,入参名为batch,它先判断batch[0]的数据类型。对于本文的问题,它是一个列表,属于container_abcs.Sequence,即执行以下代码

transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]

注意到这里有一个zip方法,而且在入参前面还加了一个*。对于*的作用,网上解释如下:在函数调用多个参数时,在列表、元组、集合、字典及其他可迭代对象前加*,可以使其自动解包并传递给相应单个变量,其中参数个数个元素个数要相等。
举个例子

def foo(a,b,c):
    print(a,b,c)

foo(*[1,2,3])

结果

1 2 3

zip(*batch)中,因为有*,所以batch中的每一个元素作为zip()的第1、2、3…个输入参数;如果不加*,那么整个batch作为zip()的唯一输入参数。区别如下

batch = [[1,2], [3,4]]

transposed = zip(*batch)
for samples in transposed:
    print(samples)

transposed = zip(batch)
for samples in transposed:
    print(samples)

结果

(1, 3)
(2, 4)
([1, 2],)
([3, 4],)

这便是问题的根源所在。
这里插一脚,zip()返回的东西可以看成迭代器,迭代完了就完了,如下

batch = [[1,2], [3,4]]
transposed = zip(*batch)
for samples in transposed:
    print(samples)
for samples in transposed:
    print(samples)

结果

(1, 3)
(2, 4)

下面继续分析问题,注意到在zip()函数后还有一句

return [default_collate(samples) for samples in transposed]

即把zip()后的结果再次逐元素地送回原来的default_collate()函数打包。
假设batch=[[1,1],[2,2],[3,3]],transposed = zip(*batch)的遍历结果为(1,2,3),(1,2,3),那么其中每个元素samples=(1,2,3)将再次被打包。由于其第0位的数据类型为int,所以进入

elif isinstance(batch[0], int_classes):
    return torch.LongTensor(batch)

即将其转为tensor类型。这些被转为tensor类型的元素被放入一个列表中返回,如果沿用前面的例子,那么返回结果为[tensor([ 1., 2., 3.], dtype=torch.float64), tensor([ 1., 2.,3.], dtype=torch.float64)]
也就是第二句中batch的值。

batch = self.collate_fn([self.dataset[i] for i in indices])

同样也是最初for循环中c_point的值。

for batch_idx, c_points in enumerate(data_loader):
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值