1 importos, json, argparse2 from threading importThread3 from Queue importQueue4
5 importnumpy as np6 from scipy.misc importimread, imresize7 importh5py8
9 """
10 Create an HDF5 file of images for training a feedforward style transfer model.11 """
12
# Command-line interface. The default paths are absolute and machine-specific;
# pass --train_dir / --val_dir / --output_file to use other locations.
parser = argparse.ArgumentParser()
for _flag, _options in [
    # Source image directories and the HDF5 file to create.
    ('--train_dir', {'default': '/media/wangxiao/WangXiao_Dataset/CoCo/train2014'}),
    ('--val_dir', {'default': '/media/wangxiao/WangXiao_Dataset/CoCo/val2014'}),
    ('--output_file', {'default': '/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5'}),
    # Spatial size every stored image is resized to.
    ('--height', {'type': int, 'default': 256}),
    ('--width', {'type': int, 'default': 256}),
    # Per-split image cap; values <= 0 mean "no cap".
    ('--max_images', {'type': int, 'default': -1}),
    # Number of image-reading threads.
    ('--num_workers', {'type': int, 'default': 2}),
    # Nonzero means the validation directory is packed as well.
    ('--include_val', {'type': int, 'default': 1}),
    # Images are cropped to a multiple of this size before resizing.
    ('--max_resize', {'default': 16, 'type': int}),
]:
    parser.add_argument(_flag, **_options)
# Drop the loop variables so the module namespace matches the explicit form.
del _flag, _options
args = parser.parse_args()
25
# Extensions (compared case-insensitively) of files treated as images.
_IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}


def _list_images(image_dir, max_images=-1):
    """Return the sorted paths of image files directly inside image_dir.

    Sorting makes the result independent of os.listdir() order, so the subset
    selected by max_images is reproducible. When max_images > 0 only the
    first max_images paths are returned.
    """
    paths = sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    )
    if max_images > 0:
        paths = paths[:max_images]
    return paths


def _to_chw_rgb(img):
    """Convert an H x W [x C] image array to a (3, H, W) uint8 array.

    Grayscale (H x W, or C == 1) is replicated into all three channels,
    grayscale+alpha (C == 2) keeps only its gray channel (replicated), and
    RGBA (C == 4) drops the alpha channel. Any other layout raises ValueError.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.ndim != 3:
        raise ValueError('unsupported image shape %r' % (img.shape,))
    channels = img.shape[2]
    if channels in (1, 2):
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif channels == 4:
        img = img[:, :, :3]
    elif channels != 3:
        raise ValueError('unsupported number of channels: %d' % channels)
    return np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.uint8)


def _load_and_resize(filename, height, width, max_resize):
    """Read filename and return it as a (3, height, width) uint8 array.

    Before resizing, the bottom rows / right columns are cropped so both
    spatial sizes are multiples of max_resize; max_resize <= 1 disables the
    crop (and avoids a modulo by zero).
    """
    img = imread(filename)
    if max_resize > 1:
        h, w = img.shape[0], img.shape[1]
        img = img[:h - h % max_resize, :w - w % max_resize]
    img = imresize(img, (height, width))
    return _to_chw_rgb(img)


def add_data(h5_file, image_dir, prefix, args):
    """Resize all images in image_dir and store them in h5_file.

    Creates a uint8 dataset named '<prefix>/images' with shape
    (num_images, 3, args.height, args.width). Images are decoded and resized
    by args.num_workers reader threads (at least one). A single writer thread
    copies them into the dataset, because h5py locks internally.

    Args:
        h5_file: open, writable h5py File or Group to create the dataset in.
        image_dir: directory scanned (non-recursively) for .jpg/.jpeg/.png
            files (extension matched case-insensitively).
        prefix: group name for the dataset, e.g. 'train2014'.
        args: namespace providing height, width, max_images, num_workers and
            max_resize. When max_images > 0, at most that many images (first
            in sorted path order) are stored, and the dataset is sized to match.

    Images that fail to load, resize or write are reported on stdout. Their
    slot is left unwritten, so it holds the dataset's fill value (0 by default
    in h5py). A bad image never blocks the function or stops the workers.
    """
    image_list = _list_images(image_dir, args.max_images)
    num_images = len(image_list)

    # HDF5 object paths always use '/', whatever the host OS separator is.
    dset_name = prefix + '/images'
    dset_size = (num_images, 3, args.height, args.width)
    imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8)

    # input_queue carries (idx, filename) tuples; output_queue carries
    # (idx, img) with img a (3, H, W) uint8 array, or None if processing
    # failed. A bare None on either queue is the shutdown sentinel.
    input_queue = Queue()
    output_queue = Queue()

    def read_worker():
        # Pull filenames, decode + resize them, and hand results to the writer.
        while True:
            item = input_queue.get()
            if item is None:
                return
            idx, filename = item
            try:
                img = _load_and_resize(filename, args.height, args.width,
                                       args.max_resize)
            except Exception as e:  # thread boundary: report, keep going
                print('Skipping %s: %r' % (filename, e))
                img = None
            output_queue.put((idx, img))

    def write_worker():
        # Sole owner of imgs_dset; writes each successfully processed image.
        num_written = 0
        while True:
            item = output_queue.get()
            if item is None:
                return
            idx, img = item
            if img is None:
                continue
            try:
                imgs_dset[idx] = img
            except Exception as e:  # thread boundary: report, keep going
                print('Failed to write image %d: %r' % (idx, e))
                continue
            num_written += 1
            if num_written % 100 == 0:
                print('Copied %d / %d images' % (num_written, num_images))

    # Clamp to one reader: with zero readers nothing would ever be processed.
    readers = [Thread(target=read_worker)
               for _ in range(max(1, args.num_workers))]
    writer = Thread(target=write_worker)
    for t in readers + [writer]:
        t.daemon = True
        t.start()

    for item in enumerate(image_list):
        input_queue.put(item)
    # One sentinel per reader; the FIFO queue delivers all real work first.
    for _ in readers:
        input_queue.put(None)
    for t in readers:
        t.join()
    # All readers have exited, so every result is already queued ahead of
    # the writer's sentinel; joining the writer means every write finished.
    output_queue.put(None)
    writer.join()
101
102
if __name__ == '__main__':
    # (source directory, HDF5 group prefix) pairs to pack; the validation
    # split is only included when --include_val is nonzero.
    splits = [(args.train_dir, 'train2014')]
    if args.include_val != 0:
        splits.append((args.val_dir, 'val2014'))

    # Mode 'w' creates the output file, truncating it if it already exists.
    with h5py.File(args.output_file, 'w') as out_file:
        for image_dir, prefix in splits:
            add_data(out_file, image_dir, prefix, args)