inferno.io.box.cifar介绍及使用
inferno简介
Inferno是一个库,提供了围绕PyTorch的实用程序和方便的函数/类,为深度学习和实现神经网络提供便利。关于inferno的其他模块介绍:
inferno Pytorch: inferno.extensions.layers.convolutional 介绍及使用
inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用
inferno Pytorch: inferno.io.transform 介绍及使用
inferno安装
pip install inferno-pytorch
inferno.io.box.cifar
inferno.io.box.cifar包含两个函数,分别用于下载cifar10
和cifar100
数据集(cifar数据集简单介绍),只需要一行代码即可下载。
源码
函数入口如下:
def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,
download=False, augment=False, validation_dataset_size=None):
def get_cifar100_loaders(root_directory, train_batch_size=128, test_batch_size=100,
download=False, augment=False, validation_dataset_size=None):
返回的是一个DataLoader.
具体源码:
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,
download=False, augment=False, validation_dataset_size=None):
# Data preparation for CIFAR10.
if augment:
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),
])
else:
transform_train = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.4914, 0.4822, 0.4465