关于Pytorch中Dataset和Dataloader的理解

使用Pytorch自定义读取数据时步骤如下:
1)创建Dataset对象
2)将Dataset对象作为参数传递到Dataloader中

详述步骤1)创建Dataset对象:
需要编写继承Dataset的类,并且覆写__getitem__和__len__方法,代码如下

class dataset(Dataset):
	def process(self):
		#对数据进行处理
		pass
	
	def __getitem__(self,index):
		pass
		
	def __len__(self):
		#返回数据的长度
		pass

(1)其中__getitem__函数的作用是根据索引index遍历数据
(2)__len__函数的作用是返回数据集的长度
(3)在创建的dataset类中可根据自己的需求对数据进行处理。可编写独立的数据处理函数,在__getitem__函数中进行调用,例如上述代码片段中的process函数;或者直接将数据处理方法写在__getitem__函数中。

详述步骤2)创建将Dataset对象作为参数传递到Dataloader中:
只需要将步骤1)创建的Dataset对象作为参数传递到Dataloader中,代码如下:

#创建对象
dataset_object = dataset()
#将dataset_object传递到Dataloader中
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
						num_workers=0, collate_fn=default_collate,
						pin_memory=False, 
						drop_last=False)

需要注意的是,Dataloader中存在一个默认的collate_fn函数,需要根据自己的需求重写collate_fn函数:
(1)该函数的作用是将数据整理成一个batch,即根据batch_size的大小一次性在数据集中取出batch_size个数据。例如数据集中有4条数据,batch_size的值为2,则每次在4条数据中取出2条数据。
(2)collate_fn函数的输入是一个list,list中的每个元素为自己编写的dataset类中__getitem__函数的返回值。
(3)Dataloader中drop_last代表将不足一个batch_size的数据是否保留,即假如有4条数据,batch_size的值为3,将取出一个batch_size之后剩余的1条数据是否仍然作为训练数据。

  • 6
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值