读取CIFAR10数据集并可视化

前言

想自己重写Dataset类，而不是通过torchvision.datasets.CIFAR10获取数据集。但是从官网下载的数据集是压缩包形式，解压后得到的是用pickle序列化的批文件（data_batch_1～data_batch_5、test_batch、batches.meta），无法直接得到图片和标签信息，因此参考博客将图片和标签读取出来。

下载数据集

首先可以通过PyTorch的torchvision.datasets.CIFAR10下载CIFAR10数据集（首次运行时需设置download=True，数据会保存在./data/cifar-10-batches-py目录下）

#train.py

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np


# device : GPU or CPU (selected and printed here; not used further in this snippet)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


# ToTensor() produces values in [0, 1]; Normalize with mean=std=0.5 per channel
# maps them to [-1, 1]. imshow() undoes exactly this (x / 2 + 0.5).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 50000 training images. download=True fetches the dataset into ./data when it
# is missing; if it is already there it is not downloaded again.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
# NOTE(review): shuffle=False is fine for inspection, but a real training loop
# would normally shuffle the training set.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=False, num_workers=0)

# 10000 validation (test-split) images
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)
val_data_iter = iter(val_loader)
# Recent PyTorch DataLoader iterators have no `.next()` method; the builtin
# next() works on every version.
val_image, val_label = next(val_data_iter)
print(val_image.size())
print(train_set.class_to_idx)
# Short display names, in the same order as the label indices 0-9.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# Used when displaying a validation batch (val_loader above uses batch_size=4).
# class_to_idx maps name -> index; invert it to look up a name from a label.
aaa = train_set.class_to_idx
cla_dict = {idx: name for name, idx in aaa.items()}
def imshow(img):
    """Show a normalized image tensor with matplotlib.

    The tensor is un-normalized (x / 2 + 0.5, reversing the mean=std=0.5
    Normalize transform), converted to a NumPy array, and its axes are
    reordered with (1, 2, 0) so the channel axis comes last as plt.imshow
    expects. The resulting array shape is printed before the figure is shown.
    """
    channels_last = np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0))
    print(channels_last.shape)
    plt.imshow(channels_last)
    plt.show()

# Print the class name of every image in the batch, right-aligned to width 5.
# Iterating over val_label (instead of a hard-coded range(4)) works for any
# batch size. Then display the whole batch as one image grid.
print(' '.join('%5s' % cla_dict[label.item()] for label in val_label))
imshow(utils.make_grid(val_image))

数据集可视化

通过pickle反序列化将数据读取出来，并把每张图片保存为.jpg文件、把对应的标签保存为同名的.txt文件。

train

readDataTrain.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Load and return the object pickled in *filename*.

    Legacy byte strings are decoded as latin-1, which gives the metadata dict
    ordinary ``str`` keys such as 'label_names'. Only use this on trusted
    files: unpickling can execute arbitrary code.
    """
    with open(filename, mode='rb') as source:
        obj = pickle.load(source, encoding='latin1')
    return obj

def unpickle(file):
    """Load and return the object pickled in *file*, keeping byte strings as bytes.

    With encoding='bytes' legacy strings are not decoded, so the CIFAR-10
    batch dicts are indexed with byte-string keys (b'data', b'labels'), as the
    code in this script does. Only use this on trusted files: unpickling can
    execute arbitrary code.
    """
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch



import os

# Class-name list from the metadata file (presumably indexed by label id 0-9).
# Not used further in this script; kept for reference.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

# Output directories ('imges' is the directory name used on disk).
# open()/imsave() fail if the directories are missing, so create them first.
train_image_dir = "data/cifar10/train/imges/"
train_label_dir = "data/cifar10/train/labels/"
os.makedirs(train_image_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)

# Running 1-based counter across all five batches, so files are numbered
# consecutively from 00001 regardless of each batch's row count.
sample_no = 0
for k in range(1, 6):
    dic = unpickle("data/cifar-10-batches-py/data_batch_" + str(k))
    dict_image_data = dic[b'data']
    dict_image_labels = dic[b'labels']

    for row, label in zip(dict_image_data, dict_image_labels):
        sample_no += 1
        stem = str(sample_no).zfill(5)
        # Each row is reshaped channel-first to (3, 32, 32), then transposed
        # to (32, 32, 3) height-width-channel order for writing as an image.
        imgs_array = np.reshape(row, (3, 32, 32)).transpose(1, 2, 0)
        # NOTE: JPEG is lossy, so saved pixels will not match the source data
        # exactly; use a '.png' suffix if pixel-exact copies are required.
        imsave(train_image_dir + stem + '.jpg', imgs_array)
        with open(train_label_dir + stem + '.txt', 'w') as f:
            f.write(str(label))


test

readDataTest.py

import pickle
from imageio import imsave
import numpy as np


def load_file(filename):
    """Return the object stored in the pickle file *filename*.

    Legacy byte strings inside the pickle are decoded as latin-1, so the
    metadata dict is indexed with ``str`` keys such as 'label_names'.
    Only call this on trusted files: unpickling can execute arbitrary code.
    """
    with open(filename, 'rb') as stream:
        return pickle.load(stream, encoding='latin1')

def unpickle(file):
    """Load and return the object pickled in *file*, keeping byte strings as bytes.

    With encoding='bytes' legacy strings are not decoded, so the CIFAR-10
    batch dict is indexed with byte-string keys (b'data', b'labels'), as the
    code in this script does. Only use this on trusted files: unpickling can
    execute arbitrary code.
    """
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch



import os

# Class-name list from the metadata file (presumably indexed by label id 0-9).
# Not used further in this script; kept for reference.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']


dic = unpickle("data/cifar-10-batches-py/test_batch")
dict_image_data = dic[b'data']
dict_image_labels = dic[b'labels']

# Output directories ('imges' is the directory name used on disk).
# open()/imsave() fail if the directories are missing, so create them first.
test_image_dir = "data/cifar10/test/imges/"
test_label_dir = "data/cifar10/test/labels/"
os.makedirs(test_image_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)

# Files are numbered consecutively from 00001.
for sample_no, (row, label) in enumerate(
        zip(dict_image_data, dict_image_labels), start=1):
    stem = str(sample_no).zfill(5)
    # Each row is reshaped channel-first to (3, 32, 32), then transposed to
    # (32, 32, 3) height-width-channel order for writing as an image.
    imgs_array = np.reshape(row, (3, 32, 32)).transpose(1, 2, 0)
    # NOTE: JPEG is lossy, so saved pixels will not match the source data
    # exactly; use a '.png' suffix if pixel-exact copies are required.
    imsave(test_image_dir + stem + '.jpg', imgs_array)
    with open(test_label_dir + stem + '.txt', 'w') as f:
        f.write(str(label))


效果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
每个标签文件中只保存一个0-9之间的整数，表示该图片的类别编号
在这里插入图片描述
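标签编号与类别名称的对应关系（即train.py中print(train_set.class_to_idx)的输出）。注意其中的类别名称（如airplane、automobile）与train.py里classes元组中的简写（plane、car）不同：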

{'airplane': 0, 'automobile': 1, 'bird': 2, 
 'cat': 3, 'deer': 4, 'dog': 5, 
 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

参考资料

手把手教你CIFAR数据集可视化

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

dotJunz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值