import numpy as np
import os
from scipy.misc import imread, imresize
import matplotlib.pyplot as plt
import pprint
# Current working directory at start-up. NOTE(review): not referenced again in
# the code shown here; relative folder paths used by os.listdir resolve against it.
cwd = os.getcwd()
#################################################################
def load_img_from_folder(path):
    """List the image entries directly inside a folder.

    An entry counts as an image when its extension, compared
    case-insensitively, is one of .jpg, .jpeg, .png or .gif. The folder is
    not searched recursively.

    Parameters
    ----------
    path : str
        Folder to scan.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(imgs, names)``: the full paths (``path`` joined with each entry
        name) and the bare entry names, both in sorted name order.
    """
    # os.path.splitext returns the extension WITH its leading dot, so every
    # entry here needs the dot too (a bare 'gif' would never match).
    valid_exts = {'.jpg', '.png', '.jpeg', '.gif'}
    imgs = []
    names = []
    # Sorted so the dataset order does not depend on the filesystem's
    # (arbitrary) directory listing order.
    for f in sorted(os.listdir(path)):
        if os.path.splitext(f)[1].lower() in valid_exts:
            imgs.append(os.path.join(path, f))
            names.append(f)
    return imgs, names
def rgb2gray(img):
    """Convert a channels-last colour image to a single luminance plane.

    Uses the Rec. 601 luma weights 0.299 R + 0.587 G + 0.114 B.

    Parameters
    ----------
    img : numpy.ndarray
        A 3-D image whose last axis holds at least three colour channels
        (extra channels such as alpha are ignored). Input of any other rank
        is treated as already grayscale.

    Returns
    -------
    numpy.ndarray
        A 2-D float array for 3-D input; otherwise ``img`` itself, unchanged.
    """
    if img.ndim == 3:
        # Weight only the first three channels so RGBA input (e.g. PNGs with
        # transparency) does not break the 3-element dot product.
        return img[..., :3].dot([0.299, 0.587, 0.114])
    return img
#################################################################
# Class folders in label order: images from paths[k] get one-hot label k.
# A list (not a set) keeps this order, and hence the class indices, stable
# from run to run.
paths = ["images/cats", "images/dogs"]
# Output size passed to imresize; each image becomes imgsize[0] x imgsize[1].
imgsize = [112, 112]
# Nonzero: convert images to one grayscale channel; 0: keep three colour channels.
use_gray = 0
# Save name (NOTE(review): not used in the code shown here)
data_name = "data4vgg"
nclass = len(paths)
#################################################################
# Collect every image path, class by class, with a matching one-hot label row.
dataset = []
label = []
for i, pa in enumerate(paths):
    images, _ = load_img_from_folder(pa)
    dataset += images
    # Row i of the identity matrix is the one-hot code for class i, repeated
    # once per image found in this class folder.
    label.append(np.tile(np.eye(nclass)[i], (len(images), 1)))
label = np.vstack(label)
imgcnt = len(dataset)
# Flattened length of one image: one value per pixel in grayscale mode,
# three per pixel in colour mode.
nchannel = 1 if use_gray else 3
# np.zeros rather than the raw np.ndarray constructor, which hands back
# uninitialized memory.
totalimg = np.zeros((imgcnt, imgsize[0] * imgsize[1] * nchannel))
totallabel = np.zeros((imgcnt, nclass))
#################################################################
# Load, channel-normalize, resize and flatten each image into totalimg, scaled
# to [0, 1] by the /255 below.
# NOTE(review): scipy.misc.imread/imresize have been removed from current
# SciPy releases; this loop needs an older SciPy with Pillow installed.
for i, (img_path, img_label) in enumerate(zip(dataset, label)):
    cur_img = imread(img_path)
    if use_gray:
        # Keep at most three channels so RGBA input does not break the
        # luminance conversion; 2-D images are already single-channel.
        if cur_img.ndim == 3:
            cur_img = cur_img[..., :3]
        cur_img = rgb2gray(cur_img)
    elif cur_img.ndim == 2:
        # Grayscale file in colour mode: replicate the plane so the flattened
        # vector matches the 3-channel row width of totalimg.
        cur_img = np.stack([cur_img] * 3, axis=-1)
    else:
        # Drop any alpha channel so exactly three channels remain.
        cur_img = cur_img[..., :3]
    # NOTE(review): for float input (the grayscale branch) imresize may
    # rescale intensities to the image's own min..max — confirm intended.
    small = imresize(cur_img, [imgsize[0], imgsize[1]]) / 255.
    totalimg[i, :] = small.ravel()
    totallabel[i, :] = img_label
#################################################################
# Shuffle once and split 80/20. A permutation uses every index exactly once,
# so train and test are disjoint and together cover the whole dataset
# (sampling indices with replacement would duplicate images and leak them
# across the split).
randidx = np.random.permutation(imgcnt)
ntrain = int(4*imgcnt/5)
trainidx = randidx[:ntrain]
testidx = randidx[ntrain:]
trainimg = totalimg[trainidx, :]
trainlabel = totallabel[trainidx, :]
testimg = totalimg[testidx, :]
testlabel = totallabel[testidx, :]
pprint.pprint(trainlabel)
# Reference:
# https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/basic_gendataset.ipynb