Dataset和DataLoader

我们一般使用for循环来训练神经网络,在每次的迭代过程,从DataLoader中取出batchsize的数据,然后前向传播反向传播一次,更新参数一次
在加载batch数据的时候,torch创建一个可迭代的Dataset对象(需要重写__getitem__()和__len__()两个方法),然后与DataLoader一起使用;

DataLoader: 构造一个整数索引的采样器来获取Dataset的数据

Dataset

创建Dataset对象:
需要重写 getitem 方法和 len 方法。
前者通过提供索引返回数据,也就是提供 DataLoader获取数据的方式;后者返回数据集的长度,DataLoader依据 len 确定自身索引采样器的长度。

from torch.utils.data import Dataset
# 输入形式:[{'x':['token_id',..],'y':[label]},..],[(['token_id',..],[label]),..]
class SampleDataset(Dataset):
	def __init__(self,data_pair): #[([x],[y]),..]
		self.x = [i[0] for i in data_pair] 
		# [[x],..] shape:(len,seq)
		self.y = [i[1] for i in data_pair] 
		# [[y],..]
		self.len = len(data_pair)
	def __getitem__(self,index):
		return self.x[index],self.y[index]
		# 注意:返回时self.x和self.y长度要相同
	def __len__(self):
		return self.len

DataLoader具体步骤:

1.DataLoader根据提供的Dataset对象,生成一个数据集大小的采样器
2.根据是否设置shuffle选择顺序采样器还是随机采样器
3.采样器根据数据集大小生成一个可迭代的序号列表[0,n-1]
4.batch_sampler根据DataLoader的batch_size将采样器提供的序列列表划分成多个batchsize大小的可迭代列表

# dataloader需接收tensor
from torch.utils.data import DataLoader
dataset = SampleDataset(train_data) #获取可迭代对象
train_data = Dataloader(dataset,batch_size=config.batch_size,shuffle=True,collate_fn=)                               

collate_fn()

首先dataloader将dataset切成一个个batch,然后调用collate_fn来整理数据

对batch排序
对x进行padding保持长度一致,否则报错(当y是文本也需要)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

石头猿rock

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值