深度学习 数据加载过程:

做深度学习:1 改网络模型的结构; 2 改 数据加载过程的代码

无非就是这两种,目前来说,网络模型结构的改动 还是比较清晰!

数据加载过程,主要分为以下三个步骤:

1 数据下载,

2 数据载入, 也就是图片的处理 。对应代码:

那些最常见的数据集,比如 mnist之类的,都是直接调用datasets.数据名()既可!

module_train = import_module('data.' + args.data_train.lower())    
trainset = getattr(module_train, args.data_train)(args)           
self.loader_train = MSDataLoader(                                   
                args,
                trainset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu
            )

比如这里的 trainset 就算数据载入,一般我们要改数据载入的过程。就是 改这个trainset里面的东西。

从哪里可以看出来是,真正的数据载入呢?

import torch.utils.data as data

要是导入了torch工具里面的data,并且有类继承了data.Dataset既可以知道,这里是真的数据载入的地方。

这个Dataset类,就是一个抽象类,一般继承了的类都会重写函数。

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

3 数据装载:打包送给模型 其实,这个过程主要就算一个打包过程,把这个数据送入到网络模型。一般就会设置以下,batchsize ,shuffle是否打乱顺序啥的。

就比如上面的代码

另外,ArbSR里面的代码:

打包用的是这个MSDataLoader(这是自己写的一个类,来打包)。一般来说,就是用torch库自带的data.DataLoader类。

class MSDataLoader(DataLoader):
    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=False,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)

然后,这个就是ArbSR里面的打包代码,这里的iter是一个迭代器,这个执行完了以后,它run的是

trainer代码里面的,for batch了:

        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            scale = hr.size(2) / lr.size(2)
            scale2 = hr.size(3) / lr.size(3)
            timer_data.hold()
            self.optimizer.zero_grad()

            # inference
            self.model.get_model().set_scale(scale, scale2)
            sr = self.model(lr)

            # loss function
            loss = self.loss(sr, hr)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值