ARTS1

Pytorch dataloader


时间: 2020年5月24日

Tag: ARTS

content:

  • 1.pytorch dataloader source code
  • 2.algorithm
  • 3.thoughts

在这里插入图片描述
深度学习训练分为5大步骤:数据,模型,损失函数,优化策略,训练。本次总结是针对Pytorch模型中的数据涉及部分的框架中涉及到的类,希望可以深度理解模型训练过程的具体实现。涉及到的具体类所在的.py文件如下:

  • torch.utils.data.dataset.py
  • torch.utils.data.dataloader.py
  • torch.utils.data._utils.fetch.py
  • torch.utils.data.sampler.py
dataset

了解Pytorch数据dataloader类写法的人都知道,设计数据类的时候,首先是要实现dataset来继承torch.utils.data.dataset.py中的dataset类。并且实现__len__和__getitem__方法。这里我们可以看Dataset类的定义如下:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nRfr56Ng-1590333057774)(dataset.png)]

这里明确指出所有的子类都需要overwrite len,此函数的功能是返回数据集的大小。这里没有说一定要实现__getitem__.。但是基本所有的实现都这个单独取数据的函数。

dataloader

接下来是将实现的dataset作为入参来构建dataloader.框架中的dataloader的定义如下,由于内容比较多,就截取了比较重要的部分

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

这里对数据集,batch_size, sampler进行注册。

	def __iter__(self):
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        return _MultiProcessingDataLoaderIter(self)

这个函数就是在训练的时候使用 for i, data in enumerate(train_loader):逐批量选取训练集中元素的入口。

SingleProcessDataLoaderIter

接下来以单线程读取数据为例对取数据的过程进行说明。

在这里插入图片描述

主要过程就是 next(self) func中的self.next_index()和self.dataset_fetcher.fetch

其中的next_index调用的子类_BaseDataLoaderIter 的函数。我们看_BaseDataLoaderIter的定义:

在这里插入图片描述

本质还是调用的传入参数的_index_sampler函数,经过重重的关系抽离,我们发现这个函数是由dadaloader中的_index_sampler定义的。

在这里插入图片描述

接下来就是安装sampler中定义的采样数据的规则,
在这里插入图片描述
再数据集中采样index。

dataset_fetcher.fetch

获取了一个batch中所有数据的Index后就是按照index在dataset中调用__getitem__函数来取出对应的每个数据。

在这里插入图片描述

最后还是调用dataloader中的collate_fn函数将取出的每个数据整理成训练使用的的tensor的形式

在这里插入图片描述

torch.utils.data.sampler.py

这个函数里面有很多关于transform的方式,具体实参见其中的一个方式:

在这里插入图片描述

algorithm

leetcode 121: best time to buy and sell stock

好久没有刷题先来找找感觉,找个入门的题目,见笑了。

Question: Say you have an array for which the ith element is the price of a given stock on day i.

If you were only permitted to complete at most one transaction (i.e., buy one and sell one share of the stock), design an algorithm to find the maximum profit.

Note that you cannot sell a stock before you buy one

解法1: brute force

	class Solution {
	int maxProfit(int prices[]) {
		int maxprofit = 0;
    	for (int i = 0; i < prices.length - 1; i++) {
			for (int j = i + 1; j < prices.length; j++) {
			int profit = prices[j] - prices[i];
			if (profit > maxprofit){
				maxprofit = profit;
			}
    }
    return maxprofit;
}

解法二:
利用一个变量来记录遍历数组过程中遇到的最小值,每次遍历的时候都和最小值来比较。

		class Solution {
			public:
			int maxProfit(vector<int>& prices) {
    		if(prices.size() < 2) return 0;
    
    		int res = 0, buy = INT_MAX;
    		for(int i=0; i < prices.size(); ++i)
   		{
        			buy = min(buy, prices[i]);
        			res = max(res, prices[i]-buy);
    		}
    
    		return res;
    
		}
		};

思考与总结:

本人是算法小白,最近开始认真研读Pytorch框架。通过对源码的自习分析,加深了对深度学习训练过程中网络设计的理解。希望在最近的几周,完成Pytorch和caffe源码的对比学习,对于Pytorch,目前计划分为5个模块,本次是对其中第一个模块,数据部分的整理小姐。最后希望可以集百家所长,掌握更好的炼丹技巧,在炼丹的路上越走越远。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值