dataset和dataloader学习笔记

获取一个batch数据的步骤

1,首先我们要确定数据集的长度n, 从dataset的__len__方法中可以得到这个值。

结果类似:n = 1000。

2,然后我们从0到n-1的范围中抽样出m个数(batch大小),由 DataLoader的 sampler和 batch_sampler参数进行抽样。

假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]

3,接着我们从数据集中去取这m个数对应下标的元素,由 Dataset的 __getitem__方法实现的

拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]

4,最后我们将结果整理成两个张量作为输出,由 DataLoader的 collate_fn参数指定的

拿到的结果是两个张量,类似batch = (features,labels),

其中 features = torch.stack([X[1],X[4],X[8],X[9]])

labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

代码示例

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


def collate_fn(batch):
    text, label = zip(*batch)
    text = list(text)
    label = list(label)
    for i in range(len(text)):
        text[i] = text[i] + '啧啧啧'
        label[i] = label[i] + 2
    return text, torch.tensor(label)

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = pd.read_csv(data, encoding='gbk')  # 进行初始化,这里是拿到所有的数据

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text, label = self.data.iloc[idx, 0], self.data.iloc[idx, 1]  # 这里是根据idx来返回数据,具体怎么返回根据自己的需求来。比如我的数据是读取的一个csv数据,我就需要用".iloc[x,y]"的索引方式来获取数据
        return text, label

datasets = MyDataset('test-data.csv')
dataloader = DataLoader(datasets, batch_size=2, shuffle=True, collate_fn=collate_fn)  # 初始化好datasets之后,就可以放到DataLoader中了,dataloader常用的参数有: dataset, batch_size, shuffle, num_workers, drop_last, collate_fn等

for i, (text, label) in enumerate(dataloader):
    print(i)
    print(text)
    print(label)

代码解读
通过获取一个batch数据的讲解,加上代码中的注释基本就能搞明白是怎么组织数据和获取数据的了,这里额外讲解一下dataloader中的collate_fn参数的作用。

这个函数其实就是对获取到的一个batch的数据进行处理,比如将数字转换为tensor格式等,当然我们也可以在这个函数中对读取到的数据做额外的操作,比如下面展示了设置和不设置collate_fn参数时打印的内容:

# test-data.csv中的数据如下:
text,labels
你好,1
你知道,1
我是谁,0
我是你,1
困困困,0
进进进,1
哈哈哈,0
不说了,0
再也不,1
你就是,0

# 不设置collate_fn参数时打印的内容
0
('困困困', '我是你')
tensor([0, 1])
1
('你就是', '我是谁')
tensor([0, 0])
2
('进进进', '哈哈哈')
tensor([1, 0])
3
('不说了', '你好')
tensor([0, 1])
4
('你知道', '再也不')
tensor([1, 1])

# 设置collate_fn参数时打印的内容,可以看到每个text的内容都被加了"啧啧啧",每个label的数字都加了2。
# 此外原始的元组存储text的方式也变成了列表存储的的方式,这些都是在collate_fn函数中进行操作的
# 当然,这些操作也可以在__getitem__获取数据的时候操作
0
['你知道啧啧啧', '困困困啧啧啧']
tensor([3, 2])
1
['我是谁啧啧啧', '进进进啧啧啧']
tensor([2, 3])
2
['你好啧啧啧', '不说了啧啧啧']
tensor([3, 2])
3
['哈哈哈啧啧啧', '我是你啧啧啧']
tensor([2, 3])
4
['再也不啧啧啧', '你就是啧啧啧']
tensor([3, 2])
  • 5
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值