说明:直接将 CIFAR-100 的二进制(pickle)数据文件加载到 PyTorch
'''
直接加载文件到pytorch
meta
test
train
'''
import os
import cv2
import pickle
import time
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms
def load_CIFAR_100(root, train=True, fine_label=True):
    """Load one split of the CIFAR-100 "python version" pickle files.

    Args:
        root: Directory containing the ``train`` / ``test`` pickle files
            (e.g. the extracted ``cifar-100-python`` folder). A trailing
            path separator is optional.
        train: True to load the training split (file ``train``), False to
            load the test split (file ``test``).
        fine_label: True to return the 100-class ``fine_labels``; False to
            return the 20-class ``coarse_labels`` (superclasses).

    Returns:
        (X, Y): ``X`` is the image array with shape (N, 32, 32, 3), i.e.
        channels last; ``Y`` is a 1-D ``np.ndarray`` of the N labels.

    Raises:
        OSError: if the split file cannot be opened.
        ValueError: if the stored data cannot be reshaped to (N, 3, 32, 32).
    """
    filename = os.path.join(root, 'train' if train else 'test')
    # NOTE: pickle.load can execute arbitrary code while unpickling; only
    # point this at trusted files (e.g. the official CIFAR-100 download).
    with open(filename, 'rb') as f:
        # encoding='bytes' decodes Python 2 8-bit strings as bytes, which is
        # why every key below is a bytes literal (b'data', b'fine_labels', ...).
        datadict = pickle.load(f, encoding='bytes')

    # Each row holds one image as channel-major values (3 x 32 x 32); move the
    # channel axis last -> (N, 32, 32, 3). The row count N is taken from the
    # file (-1) instead of being hard-coded to 50000 / 10000.
    X = datadict[b'data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    # fine_labels: 100 classes. coarse_labels: 20 superclasses, each grouping
    # 5 fine classes (e.g. "trees" = maple, oak, palm, pine, willow).
    label_key = b'fine_labels' if fine_label else b'coarse_labels'
    Y = np.array(datadict[label_key])
    return X, Y
class DealDataset(Data.Dataset):
    """Map-style dataset over one CIFAR-100 split held fully in memory.

    The split is read once at construction via ``load_CIFAR_100``; items are
    then served by integer index as ``(image, label)`` pairs.
    """

    def __init__(self, root, train=True, fine_label=True, transform=None):
        """
        Args:
            root: forwarded unchanged to ``load_CIFAR_100``.
            train: True for the training split, False for the test split.
            fine_label: True for the 100-class labels, False for the 20-class ones.
            transform: optional callable applied to each image in
                ``__getitem__`` (e.g. ``transforms.ToTensor()``); None means
                images are returned as stored.
        """
        images, labels = load_CIFAR_100(root, train=train, fine_label=fine_label)
        self.x = images
        self.y = labels
        self.train = train  # stored for callers; not read inside this class
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` at *index*; the label is a plain Python int."""
        sample = self.x[index]
        label = int(self.y[index])
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of images in the loaded split."""
        return len(self.x)
# Dataset location. Defaults to the original hard-coded Windows path, but can
# be overridden with the CIFAR100_ROOT environment variable so the script runs
# on other machines without editing the source. The trailing '/' keeps the
# path valid when load_CIFAR_100 appends the split file name.
root = os.environ.get('CIFAR100_ROOT', r'E:\cifar-100-python') + '/'
batch_size = 20
# Instantiate one Dataset per split; these are then handed to DataLoader.
# ToTensor turns each (32, 32, 3) image into a (3, 32, 32) tensor
# (and, for uint8 input, scales values to [0, 1]).
trainDataset = DealDataset(root, train=True, fine_label=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, fine_label=True, transform=transforms.ToTensor())
# Loaders for the training and test data; each batch holds batch_size images.
# NOTE(review): the training loader does not shuffle, so every epoch sees the
# same order; consider shuffle=True when actually training.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size,
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)
if __name__ == '__main__':
    # trainDataset keeps the whole split as NumPy arrays: images in x, labels in y.
    print(f'trainDataset.y.shape:{trainDataset.y.shape}\n')
    print(f'trainDataset.x.shape:{trainDataset.x.shape}\n')
    # train_loader exposes batch_size (int) and dataset (the DealDataset above),
    # so the same arrays are reachable through train_loader.dataset.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.y.shape: {train_loader.dataset.y.shape}\n')
    print(f'train_loader.dataset.x.shape: {train_loader.dataset.x.shape}\n')

    # --- Visualization 1: OpenCV ---------------------------------------------
    images, labels = next(iter(train_loader))
    # Tile the batch into a single (C, H, W) image, 10 images per row.
    img = torchvision.utils.make_grid(images, nrow=10)
    # Move channels last -> (H, W, C), the layout OpenCV expects.
    img = img.numpy().transpose(1, 2, 0)
    # OpenCV displays BGR, but this image is RGB, so convert before showing.
    cv2.imshow('img', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    cv2.waitKey(0)
    # Close the window explicitly once a key has been pressed.
    cv2.destroyAllWindows()
运行结果:
弹出一个显示窗口,展示第一个训练批次的 20 张图片(每行 10 张)。