小白一枚，话不多说，直接上代码，亲测能跑。
预处理：
import numpy as np
import tensorflow as tf
import pickle
import os
import random
def _random_crop(batch, crop_shape, padding=None):
oshape = batch[0].shape
if padding:
oshape = (oshape[0] + 2 * padding, oshape[1] + 2 * padding)
new_batch = []
npad = ((padding, padding), (padding, padding), (0, 0))
for i in range(len(batch)):
new_batch.append(batch[i])
if padding:
new_batch[i] = np.lib.pad(batch[i], pad_width=npad, mode='constant', constant_values=0)
nh = random.randint(0, oshape[0] - crop_shape[0])
nw = random.randint(0, oshape[1] - crop_shape[1])
new_batch[i] = new_batch[i][nh:nh + crop_shape[0], nw:nw + crop_shape[1]]
return new_batch
def _batch_random_flip_left_right(batch):
new_batch = []
for img in batch:
if bool(random.randint(0, 1)):
img = np.fliplr(img)
new_batch.append(img)
new_batch = np.array(new_batch)
return new_batch
def _unpickle_cifar_batch(path):
    """Load one pickled CIFAR batch file and return the unpickled object.

    On Python 3 the file is read with ``encoding='latin1'``. This is needed
    for pickles written by Python 2: byte strings and NumPy array payloads
    decode correctly, and the dict keys come back as ``str``. Python 2's
    ``pickle.load`` has no ``encoding`` parameter, so the plain call is used
    there.

    NOTE: unpickling can execute arbitrary code. Only point this at trusted
    dataset files.
    """
    import sys
    with open(path, 'rb') as f:
        if sys.version_info[0] >= 3:
            return pickle.load(f, encoding='latin1')
        return pickle.load(f)


def training_data(data_dir):
    """Load the CIFAR-10 training batches found in *data_dir*.

    Only files named ``data_batch_1`` .. ``data_batch_5`` are read; missing
    ones are skipped and other files are ignored. Batches are concatenated
    in numeric order, so the result does not depend on ``os.listdir``
    ordering.

    Each row of a batch's ``'data'`` array must hold 3072 values: three
    consecutive 1024-value planes (the R, G, B channels in the CIFAR-10
    layout). Each plane is a row-major 32x32 image. Rows are converted to
    height x width x channel images.

    Args:
        data_dir: directory containing the CIFAR-10 python batch files.

    Returns:
        ``(images, labels)``:
        - ``images`` is a contiguous array of shape ``(N, 32, 32, 3)`` with
          the dtype of the stored data.
        - ``labels`` is a list of the N labels in the same order.

    Raises:
        IOError: if none of the training batch files is present.
    """
    batch_names = ['data_batch_%d' % k for k in range(1, 6)]
    present = set(os.listdir(data_dir))
    data_chunks = []
    labels = []
    for name in batch_names:
        if name not in present:
            continue
        batch = _unpickle_cifar_batch(os.path.join(data_dir, name))
        data_chunks.append(batch['data'])
        labels.extend(batch['labels'])
    if not data_chunks:
        raise IOError('no CIFAR-10 training batches found in %r' % (data_dir,))
    # Stack once instead of re-stacking inside the loop (avoids quadratic copying).
    rows = np.vstack(data_chunks)
    # (N, 3072) -> (N, 3, 32, 32) channel planes -> (N, 32, 32, 3) channel-last.
    images = np.ascontiguousarray(rows.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1))
    print('========Finish loading training data========')
    return images, labels
def testing_data(data_dir):
data_lst = os.listdir(data_dir)
img = None
labels = None
for file in data_lst:
if file == 'test_batch':
new_path = os.path.join(data_dir, file)
with open(new_path, 'rb') as f:
dict = pickle.load(f)
if img is None:
img = dict['data'