pytorch之torch.utils.data

首先import torch.utils.data 

训练模型前的第一件事情就是导入训练时所需要用到的数据。

首先定义数据类

形式为下图代码所示,大概是这个形式。(根据自身需要用到的数据集进行调整)

import torch.utils.data
class MyDataset(torch.utils.data.Dataset):
    def __init__(self,train_data,test_data):
        super(MyDataset,self).__init__()
        self.train_data=train_data
        self.test_data=test_data

    def __getitem(self,item):
        return self.train_data(item)

    def __len__(self):
        return self.train_data.shape()
        
        

我们自定义的类要继承torch.utils.data中Dataset这个父类,在init初始化方法中采用super()这个特殊函数,super函数里必须要包含两个参数,分别是子类名和参数self,这样你的自定义数据类就可继承Dataset父类的方法。对于Dataset类,查看类的声明可以得知,必须重写 getitem 和 len 这两个函数。

定义完MyDataset后使用torch.utils.data.DataLoader。DataLoader是pytorch中读取数据的一个重要接口,基本上用pytorch训练都会用到。这个接口的目的是将自定义的Dataset根据batch size大小,是否shuffle等选项封装成一个batch size大小的tensor。

处理过程如下:

dataset=MyDataset()
loader=torch.utils.data.DataLoader(dataset,batch_size=2,shuffle=True,...)

 以下是DataLoader初始化的参数:(按住ctrl 点击类名进行查询,自行根据所需调整相应的参数即可)

 关于DataLoader更深一步的了解可跳转这篇文章:http://t.csdn.cn/MiRdH

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值