import os

try:  # Python 2: cPickle is the C-accelerated pickle module
    import cPickle as pickle
except ImportError:  # Python 3: plain pickle is already C-accelerated
    import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
CURRENT_DIR = os.getcwd()
def read_cifar10_train_data(dataset_file_path):
    """Load the five CIFAR-10 training batches into memory.

    Parameters
    ----------
    dataset_file_path : str
        Directory containing ``data_batch_1`` .. ``data_batch_5``, including
        a trailing path separator (the batch file name is appended directly
        to this string).

    Returns
    -------
    tuple of np.ndarray
        ``images`` with shape (N, 32, 32, 3) in HWC channel order, and
        ``one_hot_labels`` with shape (N, 10) containing one 1.0 per row.
    """
    train_X = None
    train_Y = None
    for i in range(1, 6):
        file_path = dataset_file_path + 'data_batch_' + str(i)
        with open(file_path, 'rb') as fo:
            # The official CIFAR-10 batches were pickled under Python 2;
            # latin1 decodes their byte strings safely under Python 3.
            batch = pickle.load(fo, encoding='latin1')
        if train_X is None:
            train_X = batch['data']
            train_Y = np.asarray(batch['labels'])
        else:
            train_X = np.concatenate((train_X, batch['data']), axis=0)
            train_Y = np.concatenate((train_Y, np.asarray(batch['labels'])), axis=0)
    # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3).  -1 instead of the
    # hard-coded 50000 so partial datasets load too.
    train_X = train_X.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    # Vectorized one-hot encoding (np.float was removed from NumPy >= 1.24).
    train_y_vec = np.zeros((len(train_Y), 10), dtype=np.float64)
    train_y_vec[np.arange(len(train_Y)), train_Y.astype(int)] = 1.
    return train_X, train_y_vec
def read_cifar10_test_data(dataset_file_path):
    """Load the CIFAR-10 test batch into memory.

    Parameters
    ----------
    dataset_file_path : str
        Directory containing ``test_batch``, including a trailing path
        separator (the batch file name is appended directly to this string).

    Returns
    -------
    tuple of np.ndarray
        ``images`` with shape (N, 32, 32, 3) in HWC channel order, and
        ``one_hot_labels`` with shape (N, 10) containing one 1.0 per row.
    """
    file_path = dataset_file_path + 'test_batch'
    with open(file_path, 'rb') as fo:
        # The official CIFAR-10 batch was pickled under Python 2; latin1
        # decodes its byte strings safely under Python 3.
        batch = pickle.load(fo, encoding='latin1')
    # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3).  -1 instead of the
    # hard-coded 10000 so partial datasets load too.
    test_X = batch['data'].reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    test_Y = np.asarray(batch['labels'])
    # Vectorized one-hot encoding (np.float was removed from NumPy >= 1.24).
    test_y_vec = np.zeros((len(test_Y), 10), dtype=np.float64)
    test_y_vec[np.arange(len(test_Y)), test_Y.astype(int)] = 1.
    return test_X, test_y_vec
class BatchReadData(object):
    """In-memory mini-batch iterator over the CIFAR-10 dataset.

    Loads the whole split into memory once, then serves resized,
    [0, 1]-scaled image batches with matching one-hot labels.
    """

    def __init__(self, dataset_file_path, output_size=(227, 227), train_data=True, shuffle=False):
        """Read the requested split and optionally shuffle it.

        Parameters
        ----------
        dataset_file_path : str
            Directory of the unpacked ``cifar-10-batches-py`` files,
            including a trailing path separator.
        output_size : sequence of 2 ints
            (height, width) of the resized output images.  A tuple default
            avoids the shared mutable-default pitfall; lists still work.
        train_data : bool
            True for the training split, False for the test split.
        shuffle : bool
            Shuffle on construction and on every pointer reset.
        """
        self.output_size = output_size
        self.shuffle = shuffle
        self.pointer = 0
        if train_data:
            self.images, self.labels = read_cifar10_train_data(dataset_file_path)
        else:
            self.images, self.labels = read_cifar10_test_data(dataset_file_path)
        if self.shuffle:
            self.shuffle_data()

    def reset_pointer(self):
        """Rewind to the first sample, reshuffling when shuffle is enabled."""
        self.pointer = 0
        if self.shuffle:
            self.shuffle_data()

    def shuffle_data(self):
        """Apply one random permutation jointly to images and labels."""
        idx = np.random.permutation(len(self.labels))
        self.images = np.asarray(self.images)[idx]
        self.labels = np.asarray(self.labels)[idx]

    def next_batch(self, batch_size):
        """Return the next batch as ``(images, labels)``.

        ``images`` has shape (n, H, W, 3) with float values scaled to
        [0, 1], where (H, W) is ``self.output_size`` and n is at most
        ``batch_size`` (smaller for the final partial batch, instead of
        padding it with uninitialized rows).
        """
        batch_images = self.images[self.pointer:self.pointer + batch_size]
        batch_labels = self.labels[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size
        # Honor the configured size instead of a hard-coded (227, 227).
        h, w = self.output_size[0], self.output_size[1]
        out = np.zeros((len(batch_images), h, w, 3))
        for i in range(len(batch_images)):
            # PIL's resize takes (width, height); bicubic as before.
            resized = Image.fromarray(batch_images[i]).resize((w, h), Image.BICUBIC)
            out[i, :, :, :] = np.array(resized)
        return out / 255., batch_labels
# ---------------------------------------------------------------------------
# Smoke test: read the CIFAR-10 test split and display the first two images
# of several batches.  Guarded so importing this module has no side effects.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    dataset_file_path = CURRENT_DIR + '/data/cifar-10-batches-py/'
    reader = BatchReadData(dataset_file_path, (227, 227), False, False)
    for i in range(10):
        images, labels = reader.next_batch(100)
        fig, axarr = plt.subplots(1, 2)
        axarr[0].imshow(images[0])
        axarr[1].imshow(images[1])
        print(labels[0], labels[1])
        plt.show()
        # Exercise reset_pointer part-way through the loop.
        if i == 3:
            reader.reset_pointer()
# Suitable for reading the data in batches, so little memory is needed.
# Adapted from GitHub.