(pytorch进阶之路)pytorch训练代码编写技巧、DataLoader、爱因斯坦标示

边角料就随便丢在一篇里面了


自定义Dataset

重写init,len,getitem三个函数
getitem尽量从内存读(init中的),避免读磁盘
若数据太大,可以维持一个固定的内存池,偶尔从磁盘读

DateLoader

若getitem包含运算,则设置num_work>0(默认情况是0),并行读取
pin_memory设为True与non_blocking设为True结合,有时候可以加速,CPU数据移动到GPU数据是异步进行的

benchmark

torch.backends.cudnn.benchmark = True,开启benchmark加速卷积神经网络
GPU越差效果越好
卷积输入固定

Dataset & DataLoader

Dataset这个类处理单个训练样本,磁盘中读取训练数据集,做一些预处理,最终变成xy训练对。

DataLoader是对于多个样本而言的,Dataset得到单个训练样本后,通过dataloader获得训练所需的batch形式,对多个样本组合成一个batch,或者在每个周期以后对数据进行一个打乱,或者将数据固定的保存在GPU中

下面是例子,创建自己的dataset,为了方便演示,这里就没用读磁盘操作了

import os
import pandas as pd
from torchvision.io import read_image
import torch.utils.data


class DemoDataset(torch.utils.data.Dataset):
    def __init__(self, all_x, all_y, transform=None, target_transform=None):
        self.all_y = all_y
        self.all_x = all_x
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.all_y)

    def __getitem__(self, idx):
        x = self.all_x[idx]
        y = self.all_y[idx]
        if self.transform:
            # x预处理,如特殊符号,数字过滤
            x = self.transform(x)
        if self.target_transform:
            # label预处理
            y = self.target_transform(y)
        return x, y


def test_demo_dataset():
    bs, c, h, w = 10, 3, 16, 16
    train_x = torch.randn([bs, c, h, w])
    train_y = torch.randint(2, [bs, ])
    train_dataset = DemoDataset(train_x, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=2,
                                                   shuffle=True)
    x, y = next(iter(train_dataloader))
    print(x.shape, y.shape)


if __name__ == '__main__':
    test_demo_dataset()

dataset处理完后,就需要对单一的样本组成batch,所以dataset实际上是一次处理一个样本的,一次返回一个特征与之对应的标签。
使用dataloader组合batch,和对样本进行打乱,并且使用multiprocessing加快数据检索速度

dataloader里面的几个参数,
sampler或者batch_sampler(一般不用batch_sampler)可以以自定义的方式从dataset中采样本,比如我们希望长度比较接近的几个样本放到同一个batch中,这时候我们就不能用shuffle了
collate_fn 一般对一个batch进行后处理,比如我们需要对batch进行pad,而这个pad不能预先计算出来

sampler可以参照一下RandomSampler和SequentialSampler的写法

看一眼最简单的SequentialSampler写法

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

想深入了解再看看RandomSampler写法就好

collate_fn可以参照一下_utils.collate.default_collate
看看它主要部分的写法就好

def default_collate(batch):
	elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)

我们会发现其实default_collate它什么也没干,类比地
那么我们想自己写collate_fn,就传入batch再把处理后的batch返回就好

Dataloader的iter逻辑,主要看看简单单线程,实现是用_SingleProcessDataLoaderIter类,_SingleProcessDataLoaderIter类中创建一个_dataset_fetcher实例化,使用_dataset_fetcher可以从dataset中去取数据,_next_idex()去获取索引,再用fetch函数去得到数据,最终返回数据,Dataloader的iter返回_SingleProcessDataLoaderIter这个对象去迭代数据,而_SingleProcessDataLoaderIter是继承自_BaseDataLoaderIter,所以需要去搞明白_BaseDataLoaderIter类

 	# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
    # since '_BaseDataLoaderIter' references 'DataLoader'.
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data

_BaseDataLoaderIter类,实例化参数仍然是DataLoader,在_BaseDataLoaderIter类中__next__函数调用_next_data函数,通过_next_data得到data,最终返回这个data,_next_data又在子类中实现,而_next_data实现又是通过_dataset_fetcher的fetch函数来实现

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._shared_seed = loader._get_shared_seed()
        if isinstance(self._dataset, IterDataPipe):
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_shuffle_seed(self._dataset, shared_rng)
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        # for other backends, pin_memory_device need to set. if not set
        # default behaviour is CUDA device. if pin_memory_device is selected
        # and pin_memory is not set, the default behaviour false.
        if (len(loader.pin_memory_device) == 0):
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._pin_memory_device = None
        else:
            if not loader.pin_memory:
                warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
                            "please set pin_memory to true, if you need to use the device pin memory")
                warnings.warn(warn_msg)

            self._pin_memory = loader.pin_memory
            self._pin_memory_device = loader.pin_memory_device
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._shared_seed = loader._get_shared_seed()
        if isinstance(self._dataset, IterDataPipe):
            shared_rng = torch.Generator()
            shared_rng.manual_seed(self._shared_seed)
            self._dataset = torch.utils.data.graph_settings.apply_shuffle_seed(self._dataset, shared_rng)

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                # TODO(https://github.com/pytorch/pytorch/issues/76750)
                self._reset()  # type: ignore[call-arg]
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data

    next = __next__  # Python 2 compatibility

    def __len__(self) -> int:
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

创建fetcher逻辑,通过调用create_fetcher函数来创建

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

最后的最后我们再来看看_IterableDatasetFetcher类就iter逻辑就完整结束了

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)
        self.ended = False

    def fetch(self, possibly_batched_index):
        if self.ended:
            raise StopIteration

        if self.auto_collation:
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    self.ended = True
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            data = next(self.dataset_iter)
        return self.collate_fn(data)

大体上看一眼就是把Dataset类iter一下,再用next返回一个数据,emmmmm…

Einops - 爱因斯坦标示

之所以叫爱因斯坦标示是因为对张量的标示方法是爱因斯坦提出来的,einops主要有三个api,分别是rearrange,reduce,repeat,我们来导入einops看一下怎么用就好

我们随机生成一个张量,x = [bs, ic, h, w],这时我们想将ic和h维度转置,用pytorch写法是 x.transpose(1,2),而用爱因斯坦标示可以用einops这个库的rearrange来操作

参数第一部分是原来的张量,中间有个箭头,第二部分是变换后的形状

还有变形操作,比如bs和ic可以浓缩成一个维度,我们可以用也可以用rearrange,这样的表示形式又直观又简洁

还可以用rearrange一行代码完成image2patch操作

rearrange还可以做堆叠操作,将list转为tensor,这个方便多了

import torch
import einops

bs, ic, h, w = 2, 3, 8, 8
x = torch.randn(bs, ic, h, w)
print(x.shape)
out1 = x.transpose(1, 2)
out2 = einops.rearrange(x, 'b i h w -> b h i w')
print(out1.shape)
print(out2.shape)

out1 = x.reshape([bs*ic, h, w])
out2 = einops.rearrange(x, 'b i h w -> (b i) h w')
print(out1.shape)
print(out2.shape)

out3 = einops.rearrange(out2, '(b i) h w -> b i h w', b=bs)
print(out3.shape)

# p1和p2是块的h和w, h1 * w1是patch数目
patch1 = einops.rearrange(x,
                          'b i (h1 p1) (w1 p2) -> b i (h1 w1) (p1 p2)',
                          p1=2,
                          p2=2)
# patch写成三维格式 [bs, patch_num, p_h * p_w * ic]
patch2 = einops.rearrange(x,
                          'b i (h1 p1) (w1 p2) -> b (h1 w1) (p1 p2 i)',
                          p1=2,
                          p2=2)

print(patch1.shape)
print(patch2.shape)

x_list = [x, x, x]
out1 = einops.rearrange(x_list, 'n b i h w -> n b i h w')
print(out1.shape)

第二个api是reduce,可以用它来做平均池化,当然还有其他的池化,比如min, max, sum, prod

# 平均池化
out1 = einops.reduce(x, 'b i h w -> b i h', reduction='mean')
print(out1.shape)
# 保持维度不变的池化操作
out1 = einops.reduce(x, 'b i h w -> b i h 1', reduction='sum')
print(out1.shape)
# 对h和w做最大池化
out1 = einops.reduce(x, 'b i h w -> b i', reduction='max')
print(out1.shape)

最后一个api,repeat,使用便于它扩维后复制操作,也能直接复制操作

# 扩维后复制
out1 = einops.rearrange(x, 'b i h w -> b i h w 1')
out2 = einops.repeat(out1, 'b i h w 1-> b i h w repeat_time', repeat_time=3)
print(out2.shape)

out1 = einops.repeat(x, 'b i h w -> b i (2 h) (2 w)')
print(out1.shape)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值