前言
想自己重写Dataset类,而不是通过torchvision.datasets.CIFAR10获取数据集。但是从官网下载的数据集是压缩包形式,解压后得到的是序列化(pickle)的批次文件,无法直接得到图片和标签信息,因此参考博客将图片和标签读取出来。
下载数据集
首先可以通过PyTorch(torchvision)下载CIFAR10数据集(若本地尚无数据,创建数据集时需要设置download=True):
#train.py
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
# Select the compute device: the first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Convert images to tensors, then normalize every RGB channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training split (50000 images).
# download=True fetches the archive into ./data when it is missing; with
# download=False the constructor raises if the data is not already on disk.
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                           shuffle=False, num_workers=0)

# Validation split (10000 images).
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
# batch_size=4 keeps the preview grid shown later to four images.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=4,
                                         shuffle=False, num_workers=0)

# Take the first validation batch. Use the builtin next(): DataLoader
# iterators no longer provide a .next() method in current PyTorch releases.
val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)
print(val_image.size())
print(train_set.class_to_idx)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Invert the dataset's name -> index mapping so that a predicted or stored
# label index can be turned back into its class name.
aaa = train_set.class_to_idx
cla_dict = {idx: name for name, idx in aaa.items()}
def imshow(img):
    """Show a normalized image tensor in a matplotlib window.

    The tensor is un-normalized with ``img / 2 + 0.5``, converted to a NumPy
    array, and its first axis is moved to the end (axes ``(1, 2, 0)``) before
    plotting. The shape of the transposed array is printed as a side effect.

    Args:
        img: image tensor; presumably laid out channels-first, e.g. the
            output of ``utils.make_grid`` -- confirm against the caller.
    """
    restored = img / 2 + 0.5  # undo the mean-0.5 / std-0.5 normalization
    # Transpose once and reuse the result for both the shape print and the plot.
    plottable = np.transpose(restored.numpy(), (1, 2, 0))
    print(plottable.shape)
    plt.imshow(plottable)
    plt.show()
# Print the class name of every image in the fetched validation batch, each
# right-aligned to width 5. Iterating the label tensor itself (instead of a
# hard-coded range(4)) keeps the printout correct for any batch size.
print(' '.join('%5s' % cla_dict[label] for label in val_label.tolist()))
# Tile the batch into a single grid image and display it.
imshow(utils.make_grid(val_image))
数据集可视化
通过反序列化将数据读取出来
train
readDataTrain.py
import pickle
from imageio import imsave
import numpy as np
def load_file(filename):
    """Unpickle *filename* and return the stored object.

    The pickle is decoded with ``encoding='latin1'``, so keys of a pickled
    dict come back as ``str`` (this is how ``batches.meta`` is read below).

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on files
    from a trusted source.
    """
    with open(filename, mode='rb') as stream:
        return pickle.load(stream, encoding='latin1')
def unpickle(file):
    """Unpickle *file* with ``encoding='bytes'`` and return the object.

    The data batches are read through this function and then indexed with
    bytes keys such as ``b'data'`` and ``b'labels'``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on files
    from a trusted source.
    """
    with open(file, mode='rb') as stream:
        batch = pickle.load(stream, encoding='bytes')
    return batch
import os

# Export the five pickled training batches as one image file plus one label
# text file per sample, numbered consecutively from 00001.

# Output folders (paths unchanged from the original script). They are created
# up front because neither imsave nor open() creates missing directories.
# NOTE: JPEG is lossy, so the saved pixels differ slightly from the pickled
# data; the extension is kept so existing readers still find the files.
train_image_dir = "data/cifar10/train/imges/"
train_label_dir = "data/cifar10/train/labels/"
os.makedirs(train_image_dir, exist_ok=True)
os.makedirs(train_label_dir, exist_ok=True)

# Class names in label-index order; not needed for the export itself.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

sample_no = 0  # running 1-based sample number across all batches
for k in range(1, 6):
    dic = unpickle("data/cifar-10-batches-py/data_batch_" + str(k))
    # Every row holds one flattened image: view it as (3, 32, 32) and move
    # the first axis last so each image becomes (32, 32, 3) for saving.
    batch_images = np.reshape(dic[b'data'], (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    batch_labels = dic[b'labels']
    for image, label in zip(batch_images, batch_labels):
        sample_no += 1
        stem = str(sample_no).zfill(5)
        imsave(train_image_dir + stem + '.jpg', image)
        with open(train_label_dir + stem + '.txt', 'w') as f:
            f.write(str(label))
test
readDataTest.py
import pickle
from imageio import imsave
import numpy as np
def load_file(filename):
    """Return the object pickled in *filename*, decoded as ``latin1``.

    With this encoding the keys of a pickled dict are returned as ``str``,
    which is how ``batches.meta`` is accessed below.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on files
    from a trusted source.
    """
    with open(filename, mode='rb') as handle:
        return pickle.load(handle, encoding='latin1')
def unpickle(file):
    """Return the object pickled in *file*, loaded with ``encoding='bytes'``.

    The test batch is read through this function and then indexed with
    bytes keys such as ``b'data'`` and ``b'labels'``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on files
    from a trusted source.
    """
    with open(file, mode='rb') as handle:
        loaded = pickle.load(handle, encoding='bytes')
    return loaded
import os

# Export the pickled test batch as one image file plus one label text file
# per sample, numbered consecutively from 00001.

# Output folders (paths unchanged from the original script). They are created
# up front because neither imsave nor open() creates missing directories.
# NOTE: JPEG is lossy, so the saved pixels differ slightly from the pickled
# data; the extension is kept so existing readers still find the files.
test_image_dir = "data/cifar10/test/imges/"
test_label_dir = "data/cifar10/test/labels/"
os.makedirs(test_image_dir, exist_ok=True)
os.makedirs(test_label_dir, exist_ok=True)

# Class names in label-index order; not needed for the export itself.
dic = load_file('data/cifar-10-batches-py/batches.meta')
labels_item = dic['label_names']

dic = unpickle("data/cifar-10-batches-py/test_batch")
# Every row holds one flattened image: view it as (3, 32, 32) and move the
# first axis last so each image becomes (32, 32, 3) for saving.
test_images = np.reshape(dic[b'data'], (-1, 3, 32, 32)).transpose(0, 2, 3, 1)
test_labels = dic[b'labels']
for sample_no, (image, label) in enumerate(zip(test_images, test_labels), start=1):
    stem = str(sample_no).zfill(5)
    imsave(test_image_dir + stem + '.jpg', image)
    with open(test_label_dir + stem + '.txt', 'w') as f:
        f.write(str(label))
效果
标签是0-9之间的数字
标签的对应关系
{'airplane': 0, 'automobile': 1, 'bird': 2,
'cat': 3, 'deer': 4, 'dog': 5,
'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}