Cifar10数据以及加载

import numpy as np
import os
import sys
import keras.backend as K
from six.moves import cPickle
import cv2
def load_batch(fpath, label_key='labels'):
    f = open(fpath, 'rb')
    if sys.version_info < (3,):
        d = cPickle.load(f)
    else:
        d = cPickle.load(f, encoding='bytes')
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode('utf8')] = v
        d = d_decoded
    f.close()
    data = d['data']
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels
def load_data():
    dirname = './cifar-10-batches-py/'
    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')
    for i in range(1, 6):
        fpath = os.path.join(dirname, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)
    fpath = os.path.join(dirname, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)
cifar10数据地址:链接: https://pan.baidu.com/s/1feueGEEjJ1sYZCC08Q3HFA 密码: ausw
解压后即可修改load_data的dirname即可读取,里面还有mnist数据。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
加载CIFAR10数据集,可以使用PyTorch中的torchvision模块。以下是一个简单的示例代码,可以用于加载CIFAR10数据集: ``` import torch import torchvision import torchvision.transforms as transforms # 定义数据变换 transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 加载训练集 trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) # 加载测试集 testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) # 类别标签 classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') ``` 在上面的代码中,我们首先定义了一个数据变换,将图像转换为张量并进行标准化。然后,我们使用`torchvision.datasets.CIFAR10`函数加载训练集和测试集,设置了一些参数,例如`root`表示数据集存储的路径,`train`表示是否加载训练集,`download`表示是否从网上下载数据集等。接着,我们使用`torch.utils.data.DataLoader`函数将数据加载到内存中,并设置了一些参数,例如`batch_size`表示每个batch的大小,`shuffle`表示是否打乱数据集等。最后,我们定义了10个类别标签,分别对应数据集中的10个类别。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值