# Original article: https://blog.csdn.net/ls20121006/article/details/70136156
# Converts pickled image batch files into image files and outputs their labels.
from __future__ import print_function
import matplotlib.pyplot as pyplot
import PIL.Image as Image
import pickle
import numpy as np
import random
import os,re
def unpickle(file):
    """Load one pickled image batch and return its labels and pixel data.

    The batch file is a pickled dict with byte-string keys:

    * ``b'data'``: a 10000x3072 numpy array of uint8s. Each row stores a
      32x32 colour image. The first 1024 entries contain the red channel
      values, the next 1024 the green, and the final 1024 the blue. The
      image is stored in row-major order, so the first 32 entries of a row
      are the red channel values of the first row of the image.
    * ``b'labels'``: a list of 10000 numbers in the range 0-9. The number
      at index i is the label of the ith image in ``b'data'``.

    Args:
        file: Path of the pickled batch file.

    Returns:
        A ``(labels, data)`` tuple holding the ``b'labels'`` and ``b'data'``
        entries of the batch.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If the batch lacks a ``b'labels'`` or ``b'data'`` entry.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load batch
    # files obtained from a trusted source.
    # The context manager closes the file even when pickle.load raises.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch[b'labels'], batch[b'data']
def saveImg(save_path, class_index):
    """Save every image of one class as numbered PNG files in ``save_path``.

    The name of the directory that contains ``save_path`` selects the data
    source: ``train`` reads the module-level ``train_labels`` /
    ``train_array``; any other name reads ``test_labels`` / ``test_array``.
    Files are written as ``0.png``, ``1.png``, ... in the order the matching
    images appear in the source array.

    Args:
        save_path: Output directory, e.g. ``<root>/train/airplane/``. It is
            created if missing. A trailing separator is optional.
        class_index: Label value of the images to export.
    """
    # Derive "train"/"test" from the parent directory name. Normalising first
    # makes the result independent of a trailing separator.
    train_or_test = os.path.basename(os.path.dirname(os.path.normpath(save_path)))
    if train_or_test == 'train':
        labels, source_arr = train_labels, train_array
    else:
        labels, source_arr = test_labels, test_array

    # Select the matching rows with one boolean mask. Growing an array through
    # repeated np.concatenate is quadratic and promotes the pixels to float64.
    # Assumes 8-bit pixel data, as described for the batch format.
    pixels = np.asarray(source_arr, dtype=np.uint8)
    selected = pixels[np.asarray(labels) == class_index]
    num = len(selected)
    # Each 3072-value row is channel-first: 3 planes (R, G, B) of 32x32.
    images = selected.reshape(num, 3, 32, 32)

    # Create the output folder if needed; exist_ok avoids the check-then-create race.
    os.makedirs(save_path, exist_ok=True)

    for index in range(num):
        # Reorder (channel, row, col) -> (row, col, channel), the layout PIL
        # interprets as an RGB image for uint8 input.
        rgb = np.ascontiguousarray(images[index].transpose(1, 2, 0))
        image = Image.fromarray(rgb)
        image.save(os.path.join(save_path, str(index) + ".png"), 'png')
# Create the .txt index file.
def initTxt(train_or_test_set_path, output_dir=None):
    """Write a shuffled ``<split>.txt`` index of the images below a split folder.

    Every sub-directory of ``train_or_test_set_path`` is treated as one class.
    Each file inside it contributes one line ``<folder>/<file> <class_index>``.
    The lines are shuffled and written to ``<output_dir>/<split>.txt``, where
    ``<split>`` is the last component of ``train_or_test_set_path``.

    Folders named after a CIFAR-10 class get that class's label index
    (airplane=0 ... truck=9). Any other folder name gets label 0, which is
    what the original script assigned to every folder.

    Args:
        train_or_test_set_path: Split directory, e.g. ``<root>/train/``.
            A trailing separator is optional.
        output_dir: Directory that receives the ``.txt`` file. Defaults to
            the module-level ``my_path``.
    """
    # CIFAR-10 class names in label order: the position is the label value.
    cifar10_classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')

    if output_dir is None:
        output_dir = my_path

    # Last path component, regardless of a trailing separator.
    train_or_test = os.path.basename(os.path.normpath(train_or_test_set_path))

    arr = []
    for folder in sorted(os.listdir(train_or_test_set_path)):
        folder_path = os.path.join(train_or_test_set_path, folder)
        # Skip stray files (e.g. a previously written index); only folders are classes.
        if not os.path.isdir(folder_path):
            continue
        if folder in cifar10_classes:
            class_index = cifar10_classes.index(folder)
        else:
            class_index = 0
        for file in sorted(os.listdir(folder_path)):
            arr.append(folder + "/" + file + ' ' + str(class_index))

    # Shuffle the entries, then write them to the .txt file.
    random.shuffle(arr)
    with open(os.path.join(output_dir, train_or_test + '.txt'), mode='w') as f:
        f.writelines(entry + '\n' for entry in arr)
# Base directory for all inputs and outputs. Note: this is the current working
# directory (os.getcwd()), which is not necessarily the folder of this file.
my_path = os.getcwd()
# Training-set path (only data_batch_3 is loaded here).
train_cifar_10_path = my_path + "/cifar-10-batches-py/data_batch_3"
# train_labels / train_array are module-level names read by saveImg().
train_labels, train_array = unpickle(train_cifar_10_path)
# Test-set path
test_cifar_10_path = my_path + "/cifar-10-batches-py/test_batch"
# test_labels / test_array are module-level names read by saveImg().
test_labels, test_array = unpickle(test_cifar_10_path)
# Export the images with class index 0 into the train/ and test/ "airplane" folders.
saveImg(my_path + "/train/airplane/", 0)
saveImg(my_path + "/test/airplane/", 0)
# Write train.txt and test.txt (under my_path) listing the exported images.
initTxt(my_path + '/train/')
initTxt(my_path + '/test/')