Pytorch cifar100离线加载二进制文件

说明:直接加载cifar100二进制文件到Pytorch

 


'''
直接加载文件到pytorch
    meta
    test
    train
'''


import os
import cv2
import pickle
import time
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms



def load_CIFAR_100(root, train=True, fine_label=True):
    """
    root,文件名
    train  训练数据集时取True,测试集时取False
    fine_label  如果分类为100类时取True,分类为20类时取False

     """
    if train:
        filename = root + 'train'
    else:
        filename = root + 'test'

    with open(filename, 'rb')as f:
        datadict = pickle.load(f,encoding='bytes')

        X = datadict[b'data']



        if train:
            # [50000, 32, 32, 3]
            X = X.reshape(50000, 3, 32, 32).transpose(0,2,3,1)
        else:
            # [10000, 32, 32, 3]
            X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1)

        # fine_labels细分类,共100中类别
        # coarse_labels超级类,共20中类别,每个超级类中实际包含5种fine_labels
        # 如trees类中,又包含maple, oak, palm, pine, willow,5种具体的树
        # 这里只取fine_labels
        # Y = datadict[b'coarse_labels']+datadict[b'fine_labels']
        if fine_label:
            Y = datadict[b'fine_labels']
        else:
            Y = datadict[b'coarse_labels']

        Y = np.array(Y)
        return X, Y

class DealDataset(Data.Dataset):
    """
        读取数据、初始化数据
    """
    def __init__(self, root, train=True, fine_label=True, transform=None):

        # 其实也可以直接使用torch.load(),读取之后的结果为torch.Tensor形式
        self.x, self.y = load_CIFAR_100(root, train=train, fine_label=fine_label)

        self.transform = transform
        self.train = train

    def __getitem__(self, index):


        img, target = self.x[index], int(self.y[index])

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):

        return len(self.x)



root = r'E:\cifar-100-python' + '/'
batch_size = 20

# 实例化这个类,然后我们就得到了Dataset类型的数据,记下来就将这个类传给DataLoader,就可以了。
trainDataset = DealDataset(root, train=True, fine_label=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, fine_label=True, transform=transforms.ToTensor())

# 训练数据和测试数据的装载
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size, # 一个批次可以认为是一个包,每个包中含有batch_size张图片
    shuffle=False,
)

test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)

if __name__ == '__main__':


    # 这里trainDataset包含:train_labels, train_set等属性;  数据类型均为ndarray
    print(f'trainDataset.y.shape:{trainDataset.y.shape}\n')
    print(f'trainDataset.y.shape:{trainDataset.x.shape}\n')


    # 这里train_loader包含:batch_size、dataset等属性,数据类型分别为int,DealDataset
    # dataset中又包含train_labels, train_set等属性;  数据类型均为ndarray
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.y.shape: {train_loader.dataset.y.shape}\n')
    print(f'train_loader.dataset.x.shape: {train_loader.dataset.x.shape}\n')


    # # --可视化1,使用OpenCV----------------------------------------------
    images, lables = next(iter(train_loader))
    img = torchvision.utils.make_grid(images, nrow = 10)
    img = img.numpy().transpose(1, 2, 0)
    # OpenCV默认为BGR,这里img为RGB,因此需要对调img[:,:,::-1]
    cv2.imshow('img', img[:,:,::-1])
    cv2.waitKey(0)

 

运行结果:

显示图

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值