inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用

inferno简介

Inferno是一个库,提供了围绕PyTorch的实用程序和方便的函数/类,为深度学习和实现神经网络提供便利。关于inferno的其他模块介绍:
inferno Pytorch: inferno.extensions.layers.convolutional 介绍及使用
inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用
inferno Pytorch: inferno.io.transform 介绍及使用

inferno安装

pip install inferno-pytorch

inferno.io.box.cifar

inferno.io.box.cifar包含两个函数,分别用于下载cifar10cifar100数据集(cifar数据集简单介绍),只需要一行代码即可下载。

源码

函数入口如下:

def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,
                        download=False, augment=False, validation_dataset_size=None):

def get_cifar100_loaders(root_directory, train_batch_size=128, test_batch_size=100,
                         download=False, augment=False, validation_dataset_size=None):

返回的是一个DataLoader.
具体源码:

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler


def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,
                        download=False, augment=False, validation_dataset_size=None):
    # Data preparation for CIFAR10.
    if augment:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),
        ])
    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值