Pytorch Dataset和Dataloader 学习笔记(二)

Pytorch Dataset & Dataloader

领券 m.cps3.cn

Pytorch框架下的工具包中,提供了数据处理的两个重要接口,Dataset 和 Dataloader,能够方便的使用和按批装载自己的数据集。

  1. 数据的预处理,加载数据并转化为tensor格式

  2. 使用Dataset构建自己的数据

  3. 使用Dataloader装载数据

【数据】链接:https://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA
提取码:10d4
复制这段内容后打开百度网盘手机App,操作更方便哦

数据的预处理与加载
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 数据的处理,加载转化为tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
torch.utils.data.Dataset

Dataset抽象类,用于包装构建自己的数据集,该类包括三个基本的方法:

  • __init__ 进行数据的读取操作
  • __getitem__ 数据集需支持索引访问
  • __len__ 返回数据集的长度
## 2. 构建自己的数据集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)
samples = dataset.__len__()
print("总样本数:",samples)
torch.utils.data.Dataloader

Dataloader抽象类,构建可迭代的数据集装载器,从Dataset实例对象中按batch_size装载数据以送入训练。包含以下几个参数:

  • batch_size 批大小
  • shuffle 装载的batch是否乱序
  • drop_last 不足batch大小的最后部分是否舍去
  • num_workers 是否多进程读取数据
## 3. 创建数据集装载器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)
测试
if __name__ == "__main__":
    iteration = 0
    for train_data, train_label in train_loader:
        print("x: ", train_data, "
y: ", train_label)
        iteration += 1
    ### 这里dataloader中drop_last为True,所以迭代次数应为 samples/batch_size = 6
    print("每个epoch迭代次数:",iteration)

完整代码
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset

## 1. 数据的处理,加载转化为tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

## 2. 构建自己的数据集
class Mydataset(Dataset):
    def __init__(self, train_data, label_data):
        self.train = train_data
        self.label = label_data
        self.len = len(train_data)

    def __getitem__(self, item):
        return self.train[item], self.label[item]

    def __len__(self):
        return self.len

dataset = Mydataset(x, y)

## 3. 创建数据集装载器
train_loader = DataLoader(dataset=dataset,
                          batch_size=64,
                          shuffle=True,
                          drop_last=True,
                          num_workers=4)

if __name__ == "__main__":
    iteration = 0
    samples = dataset.__len__()
    print("总样本数:", samples)
    for train_data, train_label in train_loader:
        print("x: ", train_data, "
y: ", train_label)
        iteration += 1
    ### 这里dataloader中drop_last为True,所以迭代次数应为 samples/batch_size = 6
    print("每个epoch迭代次数:",iteration)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值