CIFAR-10:按批次将 32×32 图像缩放为 227×227,以适配 AlexNet 的输入尺寸

import numpy as np
import cv2
import os
import cPickle
from PIL import Image
import matplotlib.pyplot as plt


# Working directory of the process at import time; the demo code at the
# bottom of the file builds the dataset directory path from it.
CURRENT_DIR = os.getcwd()
 

def read_cifar10_train_data(dataset_file_path):
    """Load the five CIFAR-10 training batches and one-hot encode the labels.

    Args:
        dataset_file_path: Directory containing the pickled files
            ``data_batch_1`` .. ``data_batch_5``. A trailing path separator
            is accepted but no longer required.

    Returns:
        A tuple ``(train_X, train_y_vec)``:

        * ``train_X`` -- the ``'data'`` arrays of all batches concatenated
          and rearranged to shape ``(N, 32, 32, 3)`` (channels last).
        * ``train_y_vec`` -- float64 array of shape ``(N, 10)`` holding a 1
          in the column given by each entry of ``'labels'``.
    """
    data_parts = []
    label_parts = []

    for batch_no in range(1, 6):
        file_path = os.path.join(dataset_file_path, 'data_batch_' + str(batch_no))
        with open(file_path, 'rb') as fo:
            # NOTE: unpickling can execute arbitrary code -- only load
            # dataset files obtained from a trusted source.
            batch = cPickle.load(fo)
        data_parts.append(batch['data'])
        label_parts.append(batch['labels'])

    train_X = np.concatenate(data_parts, axis=0)
    train_Y = np.concatenate(label_parts, axis=0).astype(np.intp)

    # Each row is stored as three consecutive 32x32 colour planes; infer the
    # image count (-1) instead of hard-coding it, then move channels last.
    train_X = train_X.reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)

    # One-hot encode all labels in a single indexed assignment: row i gets a
    # 1 in column train_Y[i].
    train_y_vec = np.zeros((len(train_Y), 10), dtype=np.float64)
    train_y_vec[np.arange(len(train_Y)), train_Y] = 1.

    return train_X, train_y_vec


def read_cifar10_test_data(dataset_file_path):
    """Load the CIFAR-10 test batch and one-hot encode the labels.

    Args:
        dataset_file_path: Directory containing the pickled ``test_batch``
            file. A trailing path separator is accepted but no longer
            required.

    Returns:
        A tuple ``(test_X, test_y_vec)``:

        * ``test_X`` -- the batch's ``'data'`` array rearranged to shape
          ``(N, 32, 32, 3)`` (channels last).
        * ``test_y_vec`` -- float64 array of shape ``(N, 10)`` holding a 1
          in the column given by each entry of ``'labels'``.
    """
    file_path = os.path.join(dataset_file_path, 'test_batch')
    with open(file_path, 'rb') as fo:
        # NOTE: unpickling can execute arbitrary code -- only load dataset
        # files obtained from a trusted source.
        batch = cPickle.load(fo)

    # Each row is stored as three consecutive 32x32 colour planes; infer the
    # image count (-1) instead of hard-coding it, then move channels last.
    test_X = np.asarray(batch['data']).reshape((-1, 3, 32, 32)).transpose(0, 2, 3, 1)
    test_Y = np.asarray(batch['labels']).astype(np.intp)

    # One-hot encode all labels in a single indexed assignment: row i gets a
    # 1 in column test_Y[i].
    test_y_vec = np.zeros((len(test_Y), 10), dtype=np.float64)
    test_y_vec[np.arange(len(test_Y)), test_Y] = 1.

    return test_X, test_y_vec



class BatchReadData(object):
    """Serve CIFAR-10 images in mini-batches, resized on demand.

    The selected split is kept in memory at its stored 32x32 resolution and
    only the images of the batch being requested are resized, so the
    upscaled data set never has to be held in memory as a whole.
    """

    def __init__(self, dataset_file_path, output_size=(227, 227), train_data=True, shuffle=False):
        """Load one split of the data set.

        Args:
            dataset_file_path: Directory with the CIFAR-10 batch files,
                passed on to the reader functions.
            output_size: ``[height, width]`` of the images returned by
                ``next_batch``.
            train_data: Load the training batches when true, otherwise the
                test batch.
            shuffle: Shuffle images and labels now and on every
                ``reset_pointer`` call.
        """
        # Copy so that instances never share (or mutate) the caller's object.
        self.output_size = list(output_size)
        self.shuffle = shuffle

        # Index of the first sample of the next batch.
        self.pointer = 0

        if train_data:
            self.images, self.labels = read_cifar10_train_data(dataset_file_path)
        else:
            self.images, self.labels = read_cifar10_test_data(dataset_file_path)

        if self.shuffle:
            self.shuffle_data()

    def reset_pointer(self):
        """Rewind to the first sample, reshuffling if shuffling is enabled."""
        self.pointer = 0

        if self.shuffle:
            self.shuffle_data()

    def shuffle_data(self):
        """Apply one random permutation to images and labels together."""
        order = np.random.permutation(len(self.labels))
        # Indexing with the permutation keeps both containers as arrays and
        # keeps every image aligned with its label.
        self.images = np.asarray(self.images)[order]
        self.labels = np.asarray(self.labels)[order]

    def next_batch(self, batch_size):
        """Return the next batch of resized images and their labels.

        Args:
            batch_size: Maximum number of samples to return.

        Returns:
            A tuple ``(images, labels)``. ``images`` is a float array of
            shape ``(n, height, width, 3)`` scaled to ``[0, 1]`` and
            ``labels`` holds the matching ``n`` label rows. ``n`` equals
            ``batch_size`` except at the end of the data, where it is the
            number of samples left (possibly 0), so images and labels always
            have the same length.
        """
        start = self.pointer
        batch_images = self.images[start:start + batch_size]
        labels = self.labels[start:start + batch_size]

        # Advance even past the end; reset_pointer() rewinds.
        self.pointer += batch_size

        height, width = self.output_size[0], self.output_size[1]
        # Sized by the samples actually available, so a short final batch
        # carries no uninitialised rows.
        images = np.zeros((len(batch_images), height, width, 3))

        for i, raw in enumerate(batch_images):
            # PIL expects the target size as (width, height).
            resized = Image.fromarray(raw).resize((width, height), Image.BICUBIC)
            images[i] = np.array(resized)

        return images / 255., labels
    
 
        
# Demo / smoke test: display the first two resized test images of ten batches.
# Guarded so that importing this module neither loads data nor opens windows.
if __name__ == '__main__':
    dataset_file_path = CURRENT_DIR + '/data/cifar-10-batches-py/'

    one = BatchReadData(dataset_file_path, [227, 227], False, False)

    for i in range(10):
        images, labels = one.next_batch(100)

        fig, axarr = plt.subplots(1, 2)
        axarr[0].imshow(images[0])
        axarr[1].imshow(images[1])
        # A single formatted string prints identically under Python 2 and 3.
        print('%s %s' % (labels[0], labels[1]))
        plt.show()
        if i == 3:
            # Rewind after the fourth batch to exercise reset_pointer().
            one.reset_pointer()

适合按批次读取:只在取出每个批次时才放大图像,因此不必把放大后的整个数据集一次性放入内存。

改自github

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值