pytorch学习笔记(四):输入流水线(input pipeline)

input-pipeline

引包

from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

图像预处理

# 创建个transform用来处理图像数据
transform = transforms.Compose([
    transforms.Scale(40),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor()])

准备数据

# 下载数据
train_dataset = dsets.CIFAR10(root='./data/',
                               train=True,
                               transform=transform,#用了之前定义的transform
                               download=True)

image, label = train_dataset[0]
print (image.size())
print (label)
Files already downloaded and verified
torch.Size([3, 32, 32])
6

加载数据

# data loader提供了队列和线程
train_loader = data.DataLoader(dataset=train_dataset,
                               batch_size=100,# 这里定义了batch_size
                               shuffle=True,
                               num_workers=2)
# 迭代开始,然后,队列和线程跟着也开始
data_iter = iter(train_loader)

# mini-batch 图像 和 标签
images, labels = next(data_iter)

for images, labels in train_loader:
    # 这里是训练代码
    pass
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值