Pytorch学习基础——torch.DataLoader

Pytorch官方文档:API for pytorch

DataLoader函数功能:

  1. 生成数据集的可迭代对象;
  2. 利用多线程加速batch data处理;
  3. 简洁、高效、直观的用于网络输入的数据结构,使用灵活,便于扩展

DataLoader类位于torch.utils.data包下,官方API介绍如下:

常用参数说明:

  • dataset(Dataset):输入数据集
  • batch_size(int, optional): 每个batch送入多少数据集
  • shuffle(bool, optional): 是否进行重新排列

实例:加载MNIS数据集并转化为dataloader格式:

import torch
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

#define hyperparameter
EPOCH = 1
BATCH_SIZE = 64
TIME_STEP = 28    #time_step / image_height
INPUT_SIZE = 28    #input_step / image_width
LR = 0.01
DOWNLOAD = True

#get the mnist dataset
train_data = dsets.MNIST(root='./', train=True, transform= torchvision.transforms.ToTensor(), download=DOWNLOAD)
#use dataloader to batch input dateset
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
#......#

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值