# Create input data pipeline.


def read_png(filename):
    """Load an image file and return it as a float32 RGB tensor in [0, 1].

    Args:
        filename: Scalar string tensor holding the path of the image.

    Returns:
        A float32 tensor of decoded pixels divided by 255. ``channels=3``
        forces three colour channels. Despite the name,
        ``tf.image.decode_image`` also accepts other formats (JPEG, GIF, BMP).
    """
    string = tf.read_file(filename)
    image = tf.image.decode_image(string, channels=3)
    image = tf.cast(image, tf.float32)
    # decode_image yields uint8 by default, so dividing by 255 maps to [0, 1].
    image /= 255
    return image


# Pin the input-pipeline ops to the CPU.
with tf.device("/cpu:0"):
    # Collect every file matching the glob pattern in `filedir`.
    train_files = glob.glob(filedir)
    if not train_files:
        # Report the pattern that was actually searched.
        raise RuntimeError(
            "No training images found with glob '{}'.".format(filedir))
    # One dataset element per file path.
    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
    # Shuffle over the whole file list, then repeat indefinitely.
    train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
    # Decode the images, using several parallel calls.
    train_dataset = train_dataset.map(
        read_png, num_parallel_calls=args.preprocess_threads)
    # Take a random patchsize x patchsize x 3 crop from each image.
    # NOTE(review): random_crop fails on images smaller than the patch -
    # assumes all training images are at least patchsize on each side.
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (patchsize, patchsize, 3)))
    train_dataset = train_dataset.batch(batchsize)
    # Keep up to 32 elements (here: whole batches, since prefetch comes after
    # batch) prepared ahead of the consumer.
    train_dataset = train_dataset.prefetch(32)

# Number of pixels in one training batch.
num_pixels = batchsize * patchsize ** 2

# Get training patch from dataset (TF1-style one-shot iterator).
x = train_dataset.make_one_shot_iterator().get_next()
# 3. With labels (有标签)
def get_files(filename):
    """Collect image paths and integer labels from a class-per-folder tree.

    Expects a layout like ``filename/<label>/<image>``, where every
    sub-directory name of ``filename`` is an integer class label
    (``0``, ``1``, ...). Top-level entries that are not directories, and
    entries inside a class folder that are not regular files, are ignored.

    Args:
        filename: Root directory. A trailing path separator is optional.

    Returns:
        A tuple ``(image_list, label_list)`` of equal length: image paths
        (``str``) and their ``int`` labels, shuffled jointly with
        ``np.random`` so that each path stays paired with its label.

    Raises:
        ValueError: If a class sub-directory name is not an integer.
    """
    image_paths = []
    labels = []
    # os.listdir order is arbitrary; sorting makes the result reproducible
    # under a fixed np.random seed.
    for class_name in sorted(os.listdir(filename)):
        class_dir = os.path.join(filename, class_name)
        if not os.path.isdir(class_dir):
            continue  # stray files such as .DS_Store are not classes
        label = int(class_name)
        for pic in sorted(os.listdir(class_dir)):
            pic_path = os.path.join(class_dir, pic)
            if os.path.isfile(pic_path):
                image_paths.append(pic_path)
                labels.append(label)

    # Shuffle paths and labels with one shared permutation so pairs stay aligned.
    order = np.random.permutation(len(image_paths))
    image_list = [image_paths[i] for i in order]
    label_list = [labels[i] for i in order]
    return image_list, label_list