pytorch中的Dataset和Dataloader以及collate_fn参数的作用

 先定义x和y作为数据及其标签,定义我们自己的TestDataset类,在这个类中要实现__getitem__和__len__方法,才可以在Dataloader中使用,注意看Dataloader中输出的格式。

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset

class TestDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, item):
        return self.x[item], self.y[item]

    def __len__(self):
        return len(self.x)


if __name__ == '__main__':

    BATCH_SIZE = 5      # 批训练的数据个数

    x = torch.linspace(1, 10, 10)       # x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # y data (torch tensor)

    print(x)
    print(y)
    print('--------------------')

    myDataset = TestDataset(x, y)
    for step, (batch_x, batch_y) in enumerate(myDataset):
        print('step:{},x:{},y:{}'.format(step, batch_x, batch_y))

    print('-----------------')

    loader = torch.utils.data.DataLoader(
        dataset=myDataset,      # torch TensorDataset format
        batch_size=BATCH_SIZE,      # mini batch size
        shuffle=True,               
        num_workers=2,              
    )

    for epoch in range(3):   
        for step, (batch_x, batch_y) in enumerate(loader):  
            print('epoch:{},step:{},x:{},y:{}'.format(epoch, step, batch_x, batch_y))


# 输出的结果如下
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.])
--------------------
step:0,x:1.0,y:10.0
step:1,x:2.0,y:9.0
step:2,x:3.0,y:8.0
step:3,x:4.0,y:7.0
step:4,x:5.0,y:6.0
step:5,x:6.0,y:5.0
step:6,x:7.0,y:4.0
step:7,x:8.0,y:3.0
step:8,x:9.0,y:2.0
step:9,x:10.0,y:1.0
-----------------
epoch:0,step:0,x:tensor([5., 1., 4., 8., 7.]),y:tensor([ 6., 10.,  7.,  3.,  4.])
epoch:0,step:1,x:tensor([ 3., 10.,  9.,  6.,  2.]),y:tensor([8., 1., 2., 5., 9.])
epoch:1,step:0,x:tensor([ 4.,  7., 10.,  1.,  5.]),y:tensor([ 7.,  4.,  1., 10.,  6.])
epoch:1,step:1,x:tensor([8., 9., 6., 3., 2.]),y:tensor([3., 2., 5., 8., 9.])
epoch:2,step:0,x:tensor([2., 7., 3., 1., 4.]),y:tensor([ 9.,  4.,  8., 10.,  7.])
epoch:2,step:1,x:tensor([ 8.,  9.,  6.,  5., 10.]),y:tensor([3., 2., 5., 6., 1.])

在pytorch官网中,找到collate_fn参数相关的页面。地址为https://pytorch.org/docs/stable/data.html?highlight=dataloader#module-torch.utils.data

 

个人理解为indices为当前这一批样本的索引,比如为[1,3,5,7,9],再用collate_fn函数对这样dataset进行处理

下面是我的测试代码:

import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
def test(x):
    return x

class TestDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, item):
        return self.x[item], self.y[item]

    def __len__(self):
        return len(self.x)

    def collate_fn(self, batch):
        return batch


if __name__ == '__main__':

    BATCH_SIZE = 5      # 批训练的数据个数

    x = torch.linspace(1, 10, 10)       # x data (torch tensor)
    y = torch.linspace(10, 1, 10)       # y data (torch tensor)

    myDataset = TestDataset(x, y)

    print(test([myDataset[index] for index in [1,3,5,7,9]]))

    print('----------------')


    loader2 = torch.utils.data.DataLoader(
        dataset=myDataset,  # torch TensorDataset format
        batch_size=BATCH_SIZE,  # mini batch size
        shuffle=True,  # 要不要打乱数据 (打乱比较好)
        num_workers=2,  # 多线程来读数据
        collate_fn=myDataset.collate_fn
    )

    for epoch in range(3):  # 训练所有!整套!数据 3 次
        for step, t in enumerate(loader2):  # 每一步 loader 释放一小批数据用来学习
            # 假设这里就是你训练的地方...

            # 打出来一些数据
            print('epoch:{},step:{},data:{}'.format(epoch, step, t))

    print('----------------')


下面是结果:
[(tensor(2.), tensor(9.)), (tensor(4.), tensor(7.)), (tensor(6.), tensor(5.)), (tensor(8.), tensor(3.)), (tensor(10.), tensor(1.))]
----------------
epoch:0,step:0,data:[(tensor(8.), tensor(3.)), (tensor(9.), tensor(2.)), (tensor(6.), tensor(5.)), (tensor(5.), tensor(6.)), (tensor(2.), tensor(9.))]
epoch:0,step:1,data:[(tensor(4.), tensor(7.)), (tensor(3.), tensor(8.)), (tensor(7.), tensor(4.)), (tensor(1.), tensor(10.)), (tensor(10.), tensor(1.))]
epoch:1,step:0,data:[(tensor(10.), tensor(1.)), (tensor(7.), tensor(4.)), (tensor(1.), tensor(10.)), (tensor(6.), tensor(5.)), (tensor(4.), tensor(7.))]
epoch:1,step:1,data:[(tensor(5.), tensor(6.)), (tensor(8.), tensor(3.)), (tensor(2.), tensor(9.)), (tensor(9.), tensor(2.)), (tensor(3.), tensor(8.))]
epoch:2,step:0,data:[(tensor(6.), tensor(5.)), (tensor(5.), tensor(6.)), (tensor(7.), tensor(4.)), (tensor(8.), tensor(3.)), (tensor(10.), tensor(1.))]
epoch:2,step:1,data:[(tensor(2.), tensor(9.)), (tensor(4.), tensor(7.)), (tensor(3.), tensor(8.)), (tensor(9.), tensor(2.)), (tensor(1.), tensor(10.))]
----------------    

可以看出我的一次测试结果和使用了collate_fn函数的形式是一样的。如果有理解错误的地方,欢迎指教。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值