Pytorch—Dataset and DataLoader(自定义一个数据集)

Pytorch—Dataset and DataLoader(自定义一个数据集)

所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。

Dataset和DataLoader是帮助加载数据的两个工具类。

Dataset主要是构建数据集,支持索引;DataLoader主要是创建一个读取小批量数据样本的DataLoader实例。

1. 如何定义自己的数据集

Pytorch提供两种数据集:Map式数据集和Iterable数据集。

一个Map式的数据集必须要重写getitem(self,index),len(self)两个内建方法,用来表示从索引到样本的映射(Map)。

torch.utils.data.Dataset 是一个抽象类,因此不能实例化,只能被其他子类去继承,构造一个自定义类。torch.utils.data.DataLoader 可以帮助加载数据,比如shuffle、读取小批量等,可以实例化。

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDatset(Dataset):
    def __init__(self):
        pass
    
    def __getitem__(self,index):
        pass
    
   	def __len__(self):
        pass
    
dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset,
                         batch_size=32,
                         shuffle=True,
                         num_workers=8)
  • 实现getitem(self,index)方法用于实例化的对象支持下标操作,能通过索引获取数据;

  • 实现len(self)方法用于返回对象的长度;

  • 实例化DataLoader时,初始化时常使用4个参数:dataset【数据集】、batch_size【批量大小】、shuffle【是否打乱】、num_workers【读取mini-batch数据时开启多线程个数】。

2. 糖尿病数据集实现

数据集连接:添加链接描述
提取码:jl6w

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader

class DiabetesDataset(Dataset):
    def __init__(self,filepath):#数据集较小,全部加载近了数据集中
        xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:,:-1])
        self.y_data = torch.from_numpy(xy[:,[-1]])#得到矩阵

    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]

    def __len__(self):
        return self.len
    
path = "D:\\diabetes.csv.gz"
batch_size = 32
num_workers = 8
dataset = DiabetesDataset(path)
train_loader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)

3. 糖尿病数据集分类训练模型

共分成四个部分:

  1. 准备数据集(Dataset和DataLoader)
  2. 构造模型(继承自nn.Module)
  3. 构造损失和优化器
  4. 训练(forward、backward和update)
import numpy as np
import torch
from torch.nn import Module,Linear,Sigmoid,BCELoss
from torch.utils.data import Dataset,DataLoader
from torch.utils.tensorboard import SummaryWriter

#1. 准备数据集
class DiabetesDataset(Dataset):
    def __init__(self,filepath):
        xy = np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:,:-1])
        self.y_data = torch.from_numpy(xy[:,[-1]])#得到矩阵

    def __getitem__(self,index):
        return self.x_data[index],self.y_data[index]

    def __len__(self):
        return self.len

path = "D:\\diabetes.csv.gz"
batch_size = 32
num_workers = 8
dataset = DiabetesDataset(path)
train_loader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True,num_workers=num_workers)

#2. 构造模型
class Model(Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = Linear(8,6)
        self.linear2 = Linear(6,4)
        self.linear3 = Linear(4,1)
        self.sigmoid = Sigmoid()

    def forward(self,x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x

model = Model()

#3. 构造损失和优化器
criterion = BCELoss(size_average=True)  # 求平均损失
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epochs = 100    

#4. 训练(forward、backward和update)
if __name__ == '__main__':
    for epoch in range(epochs):
        #1. Prepare data
        for i,(inputs,labels) in enumerate(train_loader):
            #2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred,labels)
            print(epoch,i,loss.item())
            #3. Backward
            optimizer.zero_grad()
            loss.backward()
            #4. Update
            optimizer.step()
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值