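CIFAR dataset homepage: http://www.cs.toronto.edu/~kriz/cifar.html
The officially distributed dataset files store the images as raw bytes, so to view them as pictures you have to write a small script yourself.
The example below uses the CIFAR-100 test set, whose file is named test: it extracts the 10000 32×32 images from that file and saves each image's label to img_label.txt.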
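Each row of the pickled `data` array holds 3072 bytes. According to the CIFAR page, the first 1024 are the red channel, the next 1024 the green channel and the last 1024 the blue channel, each stored in row-major order. The sketch below checks that `reshape(3, 32, 32)` puts the byte at offset `ch*1024 + r*32 + c` at position `[ch, r, c]`. It uses a synthetic row instead of the real file, and the helper name `flat_index` is hypothetical.

```python
import numpy as np


def flat_index(ch, r, c, size=32):
    """Hypothetical helper: offset of pixel (r, c) of channel ch in a flat CIFAR row."""
    return ch * size * size + r * size + c


# Synthetic stand-in for one 3072-byte row of datadict[b'data']
row = (np.arange(3 * 32 * 32) % 256).astype(np.uint8)
img = row.reshape(3, 32, 32)  # (channel, row, column), as in load_CIFAR_batch

for ch, r, c in [(0, 0, 0), (1, 5, 7), (2, 31, 31)]:
    assert img[ch, r, c] == row[flat_index(ch, r, c)]
print("layout check passed")
```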
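```python
import os
import pickle as p

import numpy as np
from PIL import Image


def load_CIFAR_batch(filename):
    """Load a single CIFAR-100 batch file (here: the test set)."""
    with open(filename, 'rb') as f:
        # The file is a pickled dict whose keys are bytes, e.g. b'data'
        datadict = p.load(f, encoding='bytes')
    X = datadict[b'data']          # uint8 array of shape (10000, 3072)
    Y = datadict[b'fine_labels']   # list of 10000 ints in [0, 99]
    # Each row holds 1024 R values, then 1024 G, then 1024 B (row-major)
    X = X.reshape(10000, 3, 32, 32)
    Y = np.array(Y)
    print(Y.shape)
    return X, Y


if __name__ == "__main__":
    imgX, imgY = load_CIFAR_batch("./data/cifar-100-python/test")
    # 'w' instead of 'a+' so that re-running the script does not append duplicates
    with open('img_label.txt', 'w') as f:
        for i in range(imgY.shape[0]):
            f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')

    os.makedirs("./pic1", exist_ok=True)  # the output folder must exist
    for i in range(imgX.shape[0]):
        imgs = imgX[i]
        # One single-channel ("L") image per color plane, merged into RGB
        i0 = Image.fromarray(imgs[0])
        i1 = Image.fromarray(imgs[1])
        i2 = Image.fromarray(imgs[2])
        img = Image.merge("RGB", (i0, i1, i2))
        name = "img" + str(i) + ".png"
        img.save("./pic1/" + name, "png")
    print("save successfully!")
```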
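The three per-channel images and `Image.merge` can also be replaced by a single NumPy transpose: moving the channel axis last gives `(N, 32, 32, 3)` arrays, which `Image.fromarray` interprets as RGB directly. A minimal sketch on random data standing in for `imgX`; the helper name `to_hwc` is hypothetical.

```python
import numpy as np
from PIL import Image


def to_hwc(batch_chw):
    """Hypothetical helper: (N, 3, 32, 32) uint8 -> (N, 32, 32, 3) uint8."""
    return np.ascontiguousarray(batch_chw.transpose(0, 2, 3, 1))


rng = np.random.default_rng(0)
imgX = rng.integers(0, 256, size=(4, 3, 32, 32), dtype=np.uint8)  # stand-in data

hwc = to_hwc(imgX)
fast = Image.fromarray(hwc[0])  # mode "RGB" inferred from the (32, 32, 3) uint8 shape

# Same pixels as the channel-by-channel merge used in the script
slow = Image.merge("RGB", [Image.fromarray(imgX[0][k]) for k in range(3)])
assert fast.mode == "RGB"
assert np.array_equal(np.asarray(fast), np.asarray(slow))
```

Saving then becomes `Image.fromarray(hwc[i]).save(...)` inside the loop; the transpose is done once for the whole batch.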
Generated test-set images:
Contents of img_label.txt:
img0 49
img1 33
img2 72
...
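A quick sanity check on img_label.txt: the CIFAR-100 test set contains 100 images for each of its 100 classes, so counting the second column should give 100 for every label. A sketch using `collections.Counter`; the helper name `label_counts` is hypothetical, and the demo feeds the three sample lines above instead of the real file.

```python
from collections import Counter


def label_counts(lines):
    """Hypothetical helper: count labels in lines of the form 'img<i> <label>'."""
    counts = Counter()
    for line in lines:
        if line.strip():
            _, label = line.split()
            counts[int(label)] += 1
    return counts


# Demo with the sample lines shown above; on the real file:
#   with open('img_label.txt') as f: counts = label_counts(f)
demo = label_counts(["img0 49\n", "img1 33\n", "img2 72\n"])
print(demo)

# For the full file, expect len(counts) == 100 and set(counts.values()) == {100}
```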
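To translate a label into its class name: the label index equals the position of the class name in alphabetical order. The 100 classes (20 superclasses, five classes per line) are: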
beaver, dolphin, otter, seal, whale
aquarium fish, flatfish, ray, shark, trout
orchids, poppies, roses, sunflowers, tulips
bottles, bowls, cans, cups, plates
apples, mushrooms, oranges, pears, sweet peppers
clock, computer keyboard, lamp, telephone, television
bed, chair, couch, table, wardrobe
bee, beetle, butterfly, caterpillar, cockroach
bear, leopard, lion, tiger, wolf
bridge, castle, house, road, skyscraper
cloud, forest, mountain, plain, sea
camel, cattle, chimpanzee, elephant, kangaroo
fox, porcupine, possum, raccoon, skunk
crab, lobster, snail, spider, worm
baby, boy, girl, man, woman
crocodile, dinosaur, lizard, snake, turtle
hamster, mouse, rabbit, shrew, squirrel
maple, oak, palm, pine, willow
bicycle, bus, motorcycle, pickup truck, train
lawn-mower, rocket, streetcar, tank, tractor
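Save this list to class.txt, then sort the class names and use each name's index as its label. This reproduces the official numbering only if the names sort exactly like the ones the dataset itself uses. In the dataset's meta file "computer keyboard" is named keyboard, so sorting the list above as-is puts "computer keyboard" at index 25 and shifts the class names for labels 25 to 39 by one; writing keyboard in class.txt instead avoids this.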
```python
class_ = []
with open('class.txt', 'r') as f:
    for line in f:
        # Each line holds one superclass: five comma-separated class names
        for cl in line.strip().split(','):
            if cl.strip():
                class_.append(cl.strip())
class_.sort()  # the label is the index in alphabetical order

label_to_class = {}
with open('label_class.txt', 'w') as f:
    for label, c in enumerate(class_):
        label_to_class[label] = c
        f.write(str(label) + ' ' + c + '\n')
```
This produces label_class.txt:
0 apples
1 aquarium fish
2 baby
...
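Instead of typing the class list by hand, the names can be read from the `meta` file that ships in the same `cifar-100-python` folder as `test`; its `fine_label_names` entry lists the 100 names already in label order (spelled with underscores, e.g. `aquarium_fish`). A sketch assuming the folder layout of the first script; the helper name `load_fine_label_names` is hypothetical, and the self-test writes a tiny fake meta file rather than reading the real one.

```python
import os
import pickle
import tempfile


def load_fine_label_names(meta_path):
    """Hypothetical helper: return {label: class name} from a CIFAR-100 meta file."""
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f, encoding='bytes')
    names = meta[b'fine_label_names']
    # With encoding='bytes' the names come back as bytes; decode them to str
    return {i: (n.decode() if isinstance(n, bytes) else n) for i, n in enumerate(names)}


# Real use (path assumed): load_fine_label_names("./data/cifar-100-python/meta")
# Self-test with a fake two-entry meta file:
with tempfile.TemporaryDirectory() as d:
    fake = os.path.join(d, 'meta')
    with open(fake, 'wb') as f:
        pickle.dump({b'fine_label_names': [b'apple', b'aquarium_fish']}, f)
    label_to_class = load_fine_label_names(fake)
print(label_to_class)
```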
Finally, iterate over the extracted test-set images and print each image's label and class name:
```python
import os

import PIL.Image as Image
import torch
from torchvision import transforms

# img_label.txt as written by the extraction script (adjust the path to your setup)
imgs_to_label = {}
with open('img_label.txt', 'r') as f:
    for line in f:
        img, label = line.split()
        imgs_to_label[img] = int(label)

label_to_class = {}
with open('label_class.txt', 'r') as f:
    for line in f:
        # Split only once so multi-word names such as "aquarium fish" stay intact
        label, class_ = line.strip().split(' ', 1)
        label_to_class[int(label)] = class_

test_root = 'pic1'

# Build the preprocessing pipeline once instead of once per image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])


def transform_png(png_name):
    raw_image = Image.open(os.path.join(test_root, png_name)).convert('RGB')
    clsnet_image = transform(raw_image)
    # Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32)
    clsnet_image = torch.unsqueeze(clsnet_image, 0).cpu()
    return clsnet_image


for img in imgs_to_label.keys():
    input_tensor = transform_png(img + '.png')
    label = imgs_to_label[img]
    classx = label_to_class[label]
    print(img, label, classx)
```
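The `Normalize` constants above are per-channel means and standard deviations of CIFAR-100 pixels after `ToTensor` has scaled them to [0, 1]. A sketch of how such statistics can be computed from an `(N, 3, 32, 32)` uint8 array like `imgX`; the helper name `channel_stats` is hypothetical, and the demo uses synthetic data, so its numbers differ from the constants in the script.

```python
import numpy as np


def channel_stats(batch_chw_uint8):
    """Hypothetical helper: per-channel mean/std of uint8 images scaled to [0, 1]."""
    x = batch_chw_uint8.astype(np.float64) / 255.0  # same scaling as ToTensor
    mean = x.mean(axis=(0, 2, 3))
    std = x.std(axis=(0, 2, 3))  # population std (ddof=0)
    return mean, std


# Synthetic stand-in: channel k is filled with the constant value 51 * k
demo = np.zeros((2, 3, 32, 32), dtype=np.uint8)
for k in range(3):
    demo[:, k] = 51 * k
mean, std = channel_stats(demo)
print(mean, std)
```

`transforms.Normalize(mean, std)` then computes `(x - mean) / std` for each channel, which is the same as `(x - mean[:, None, None]) / std[:, None, None]` on a `(3, 32, 32)` array.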